mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 06:55:00 +08:00

[V1] Support MP Executor for multi node distributed inference (#23691)

Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

This commit is contained in:
parent a55b64635c
commit b316ac6589

tests/distributed/test_multiproc_executor.py (new file, 437 lines)
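As a quick orientation (not part of the diff), the multi-node settings this commit introduces could be exercised roughly like this; model, addresses and sizes below are made-up placeholders, while the parameter names come from the changes shown further down:

```python
# Sketch only: the new multi-node MP options (nnodes, node_rank,
# master_addr, master_port) wired through EngineArgs. Values are hypothetical.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    tensor_parallel_size=4,             # total TP ranks across all nodes
    distributed_executor_backend="mp",  # nnodes > 1 requires the mp backend
    nnodes=2,                           # two participating nodes
    node_rank=0,                        # 0 on the head node, 1 on the second node
    master_addr="10.0.0.1",             # placeholder head-node address
    master_port=29501,                  # matches the new ParallelConfig default
)
vllm_config = engine_args.create_engine_config()
```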
@@ -0,0 +1,437 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Integration tests for MultiprocExecutor at the executor level.

This test directly tests the executor without going through the LLM interface,
focusing on executor initialization, RPC calls, and distributed execution.
"""

import multiprocessing
import os

from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_open_port
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.multiproc_executor import MultiprocExecutor

MODEL = "facebook/opt-125m"

def create_vllm_config(
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    max_model_len: int = 256,
    gpu_memory_utilization: float = 0.3,
    distributed_executor_backend: str = "mp",
    nnodes: int = 1,
    node_rank: int = 0,
    master_port: int = 0,
) -> VllmConfig:
    """Create a VllmConfig for testing using EngineArgs."""
    engine_args = EngineArgs(
        model=MODEL,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=True,
    )
    vllm_config = engine_args.create_engine_config()

    # Override distributed node settings if needed
    if nnodes > 1 or node_rank > 0:
        vllm_config.parallel_config.nnodes = nnodes
        vllm_config.parallel_config.node_rank = node_rank
        vllm_config.parallel_config.master_port = master_port
        if nnodes > 1:
            vllm_config.parallel_config.disable_custom_all_reduce = True

    return vllm_config


def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput:
    """Create a minimal SchedulerOutput for testing."""
    # This is a simplified version - in practice you'd need proper
    # SchedulerOutput construction based on the actual vLLM v1 API
    return SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_resumed_reqs=[],
        scheduled_running_reqs=[],
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
    )

def test_multiproc_executor_initialization():
    """Test that MultiprocExecutor can be initialized with proper config."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )

    # Create executor - this should initialize workers
    executor = MultiprocExecutor(vllm_config=vllm_config)

    # Verify executor properties
    assert executor.world_size == 1, "World size should be 1 for single GPU"
    assert executor.local_world_size == 1, "Local world size should be 1"
    assert hasattr(executor, "workers"), "Executor should have workers"
    assert len(executor.workers) == 1, "Should have 1 worker for single GPU"

    # Clean up
    executor.shutdown()


@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_initialization_tensor_parallel():
    """Test MultiprocExecutor initialization with tensor parallelism."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )

    # Create executor
    executor = MultiprocExecutor(vllm_config=vllm_config)

    # Verify executor properties
    assert executor.world_size == 2, "World size should be 2 for TP=2"
    assert executor.local_world_size == 2, "Local world size should be 2"
    assert len(executor.workers) == 2, "Should have 2 workers for TP=2"

    # Verify output rank calculation
    output_rank = executor._get_output_rank()
    assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1"

    # Clean up
    executor.shutdown()


@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_collective_rpc():
    """Test collective RPC calls to all workers."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )

    # Create executor
    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Test check_health RPC - should work without errors
        executor.check_health()

        # Test that RPC works correctly
        # Note: We're just testing that the RPC mechanism works,
        # not testing actual model execution here
        assert not executor.is_failed, "Executor should not be in failed state"

    finally:
        # Clean up
        executor.shutdown()

def test_multiproc_executor_failure_callback():
    """Test failure callback registration and invocation."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )

    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Test callback registration
        callback_invoked = []

        def test_callback():
            callback_invoked.append(True)

        # Register callback
        executor.register_failure_callback(test_callback)

        # Callback should not be invoked yet
        assert len(callback_invoked) == 0, "Callback should not be invoked immediately"

        # Simulate failure
        executor.is_failed = True

        # Register another callback - should be invoked immediately
        executor.register_failure_callback(test_callback)
        assert len(callback_invoked) == 1, (
            "Callback should be invoked when executor is failed"
        )

    finally:
        # Clean up
        executor.shutdown()


@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_worker_monitor():
    """Test that worker monitor is set up correctly."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )

    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Verify all worker processes are alive
        for worker in executor.workers:
            assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive"

        # Verify executor is not in failed state
        assert not executor.is_failed, "Executor should not be in failed state"

    finally:
        # Clean up
        executor.shutdown()

    # After shutdown, workers should be terminated
    import time

    time.sleep(0.5)  # Give processes time to terminate
    for worker in executor.workers:
        assert not worker.proc.is_alive(), (
            f"Worker rank {worker.rank} should terminate after shutdown"
        )

@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_get_response_message_queues():
    """Test message queue retrieval for different ranks."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
    )

    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Get all message queues
        all_queues = executor.get_response_mqs()
        assert len(all_queues) == 2, "Should have 2 message queues for 2 workers"

        # Get message queue for specific rank
        rank0_queue = executor.get_response_mqs(unique_reply_rank=0)
        assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0"

        rank1_queue = executor.get_response_mqs(unique_reply_rank=1)
        assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1"

    finally:
        # Clean up
        executor.shutdown()


def test_multiproc_executor_shutdown_cleanup():
    """Test that shutdown properly cleans up resources."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )

    executor = MultiprocExecutor(vllm_config=vllm_config)

    # Verify executor is set up
    assert hasattr(executor, "workers"), "Executor should have workers"
    assert len(executor.workers) > 0, "Should have at least one worker"

    # Shutdown
    executor.shutdown()

    # Verify cleanup
    import time

    time.sleep(0.5)  # Give processes time to terminate

    for worker in executor.workers:
        assert not worker.proc.is_alive(), "Worker processes should be terminated"

    # Verify shutdown event is set
    assert executor.shutdown_event.is_set(), "Shutdown event should be set"

    # Multiple shutdowns should be safe (idempotent)
    executor.shutdown()
    executor.shutdown()

@multi_gpu_test(num_gpus=4)
def test_multiproc_executor_pipeline_parallel():
    """Test MultiprocExecutor with pipeline parallelism."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
    )

    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Verify executor properties
        assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2"
        assert len(executor.workers) == 4, "Should have 4 workers"

        # Verify output rank calculation
        # For TP=2, PP=2: output should be from the last PP stage (ranks 2-3)
        # Specifically rank 2 (first rank of last PP stage)
        output_rank = executor._get_output_rank()
        assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)"

        # Verify max_concurrent_batches for pipeline parallel
        assert executor.max_concurrent_batches == 2, (
            "Max concurrent batches should equal PP size"
        )

    finally:
        # Clean up
        executor.shutdown()


def test_multiproc_executor_properties():
    """Test various executor properties and configurations."""
    vllm_config = create_vllm_config(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )

    executor = MultiprocExecutor(vllm_config=vllm_config)

    try:
        # Test supports_pp property
        assert MultiprocExecutor.supports_pp is True, (
            "MultiprocExecutor should support pipeline parallelism"
        )

        # Test world_size calculation
        assert executor.world_size == (
            executor.parallel_config.tensor_parallel_size
            * executor.parallel_config.pipeline_parallel_size
        ), "World size should equal TP * PP"

        # Test local_world_size calculation
        assert executor.local_world_size == (
            executor.parallel_config.world_size // executor.parallel_config.nnodes
        ), "Local world size should be world_size / nnodes"

    finally:
        # Clean up
        executor.shutdown()

@multi_gpu_test(num_gpus=4)
def test_multiproc_executor_multi_node():
    """
    Test MultiprocExecutor with multi-node configuration.
    This simulates 2 nodes with TP=4:
    - Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2
    - Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2
    Total world_size = 4, nnodes = 2
    """
    port = get_open_port()
    # symm_mem does not work for simulating multi instance in single node
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int):
        """Run a single node's executor."""
        executor = None
        try:
            # Set CUDA_VISIBLE_DEVICES for this node
            if node_rank == 0:
                os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

            # Create config for this node
            vllm_config = create_vllm_config(
                tensor_parallel_size=4,  # Total TP across all nodes
                pipeline_parallel_size=1,
                nnodes=2,  # 2 nodes
                node_rank=node_rank,
                master_port=port,  # same port
            )

            # Create executor for this node
            executor = MultiprocExecutor(vllm_config=vllm_config)

            # Verify node-specific properties
            assert executor.world_size == 4, (
                f"World size should be 4 on node {node_rank}"
            )
            assert executor.local_world_size == 2, (
                f"Local world size should be 2 on node {node_rank}"
            )
            assert len(executor.workers) == 2, (
                f"Should have 2 local workers on node {node_rank}"
            )

            # Verify worker ranks are correct for this node
            expected_ranks = [node_rank * 2, node_rank * 2 + 1]
            actual_ranks = sorted([w.rank for w in executor.workers])
            assert actual_ranks == expected_ranks, (
                f"Node {node_rank} should have workers "
                f"with ranks {expected_ranks}, got {actual_ranks}"
            )
            # Verify all workers are alive
            for worker in executor.workers:
                assert worker.proc.is_alive(), (
                    f"Worker rank {worker.rank} should be alive on node {node_rank}"
                )
            # executor.gen
            # Put success result in queue BEFORE shutdown to avoid hanging
            result_queue.put({"node": node_rank, "success": True})
            import time

            time.sleep(2)
            executor.shutdown()
        except Exception as e:
            # Put failure result in queue
            result_queue.put({"node": node_rank, "success": False, "error": str(e)})
            raise e
        finally:
            if executor is not None:
                executor.shutdown()

    # Create a queue to collect results from both processes
    result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue()

    # Start both node processes
    processes = []
    for node_rank in range(2):
        p = multiprocessing.Process(
            target=run_node,
            args=(node_rank, result_queue, port),
            name=f"Node{node_rank}",
        )
        p.start()
        processes.append(p)

    # Wait for both processes to complete
    all_completed = True
    for p in processes:
        p.join(timeout=60)
        if p.is_alive():
            p.terminate()
            p.join(timeout=20)
            if p.is_alive():
                p.kill()
                p.join()
            all_completed = False

    # Check results from both nodes
    results: list[dict[str, int | bool]] = []
    while len(results) < 2:
        try:
            result = result_queue.get(timeout=1)
            results.append(result)
        except Exception:
            pass
    assert all_completed, "Not all processes completed successfully"
    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
    assert results[0]["success"], f"Node 0 failed: {results[0]}"
    assert results[1]["success"], f"Node 1 failed: {results[1]}"
@@ -210,6 +210,18 @@ class ParallelConfig:
     class is dynamically inherited by the worker class. This is used to inject
     new attributes and methods to the worker class for use in collective_rpc
     calls."""
+    master_addr: str = "127.0.0.1"
+    """Distributed master address for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    master_port: int = 29501
+    """Distributed master port for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    node_rank: int = 0
+    """Distributed node rank for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    nnodes: int = 1
+    """Number of nodes for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+
     world_size: int = Field(init=False)
     """world_size is TPxPP, it affects the number of workers we create."""
@@ -387,6 +399,23 @@ class ParallelConfig:
             and self.data_parallel_size > 1
         )

+    @property
+    def node_rank_within_dp(self) -> int:
+        return self.node_rank % self.nnodes_within_dp
+
+    @property
+    def nnodes_within_dp(self) -> int:
+        if self.nnodes == 1:
+            return 1
+        data_parallel_node_size = (
+            self.data_parallel_size // self.data_parallel_size_local
+        )
+        return self.nnodes // data_parallel_node_size
+
+    @property
+    def local_world_size(self) -> int:
+        return self.world_size // self.nnodes_within_dp
+
     @staticmethod
     def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
         tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
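A quick worked example of the three properties added above (the numbers are hypothetical, not taken from this change): with nnodes=4, data_parallel_size=2, data_parallel_size_local=1, TP=4 and PP=1, world_size is 4 and each DP rank spans two nodes.

```python
# Hypothetical topology: nnodes=4, data_parallel_size=2,
# data_parallel_size_local=1, tensor_parallel_size=4, pipeline_parallel_size=1.
data_parallel_node_size = 2 // 1                 # node groups consumed by the DP dimension
nnodes_within_dp = 4 // data_parallel_node_size  # 2 nodes serve each DP rank
local_world_size = 4 // nnodes_within_dp         # 2 of the 4 TP workers live on each node
node_rank_within_dp = 3 % nnodes_within_dp       # node_rank=3 is rank 1 inside its DP group
```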
@@ -528,6 +557,8 @@ class ParallelConfig:
             ray_found = ray_utils.ray_is_available()
             if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                 backend = "uni"
+            elif current_platform.is_cuda() and self.nnodes > 1:
+                backend = "mp"
             elif (
                 current_platform.is_cuda()
                 and cuda_device_count_stateless() < self.world_size
@@ -565,6 +596,10 @@ class ParallelConfig:
                 "max_parallel_loading_workers is currently "
                 "not supported and will be ignored."
             )
+        if self.distributed_executor_backend != "mp" and self.nnodes > 1:
+            raise ValueError(
+                "nnodes > 1 can only be set when distributed executor backend is mp."
+            )

     @property
     def use_ray(self) -> bool:
@@ -607,6 +642,11 @@ class ParallelConfig:
                 "Disabled the custom all-reduce kernel because it is not "
                 "supported on current platform."
             )
+        if self.nnodes > 1:
+            self.disable_custom_all_reduce = True
+            logger.debug(
+                "Disabled the custom all-reduce since we are running on multi-node."
+            )
         if self.ray_workers_use_nsight and not self.use_ray:
             raise ValueError(
                 "Unable to use nsight profiling unless workers run with Ray."
@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
 from multiprocessing import shared_memory
 from pickle import PickleBuffer
 from threading import Event
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 from unittest.mock import patch

 import torch
@@ -602,13 +602,87 @@ class MessageQueue:
                 return obj
         return self.dequeue()

+    @staticmethod
+    def create_from_process_group_single_reader(
+        pg: ProcessGroup,
+        max_chunk_bytes,
+        max_chunks,
+        reader_rank: int = 0,
+        blocking: bool = False,
+    ) -> tuple["MessageQueue", list[Handle]]:
+        """
+        Creates a MessageQueue for a process group with a single reader.
+
+        This method is designed for scenarios where only one process (the reader)
+        will consume messages, and all other processes are writers. It sets up
+        the shared memory buffer and communication handles accordingly, and
+        gathers the handles from all processes to the reader.
+
+        Args:
+            pg (ProcessGroup): The torch distributed process group.
+            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
+            max_chunks (int): Maximum number of chunks in the buffer.
+            reader_rank (int, optional): The global rank that will act as the reader.
+                Defaults to 0.
+            blocking (bool, optional): If True, blocks until all processes are ready.
+                Defaults to False.
+
+        Returns:
+            tuple[MessageQueue, list[Handle]]:
+                The MessageQueue instance for the calling process,
+                and a list of handles (only non-empty for the reader process).
+        """
+        local_size = torch.cuda.device_count()
+        rank = dist.get_rank()
+        same_node = rank // local_size == reader_rank // local_size
+        buffer_io = MessageQueue(
+            n_reader=1,
+            n_local_reader=1 if same_node else 0,
+            max_chunk_bytes=max_chunk_bytes,
+            max_chunks=max_chunks,
+        )
+        handle = buffer_io.export_handle()
+        handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
+        dist.gather_object(handle, handles, dst=reader_rank, group=pg)
+        if blocking:
+            buffer_io.wait_until_ready()
+        return buffer_io, cast(list[Handle], handles or [])
+
     @staticmethod
     def create_from_process_group(
         pg: ProcessGroup | StatelessProcessGroup,
         max_chunk_bytes,
         max_chunks,
-        writer_rank=0,
+        writer_rank: int = 0,
+        external_writer_handle=None,
+        blocking: bool = True,
     ) -> "MessageQueue":
+        """
+        Creates a MessageQueue for a distributed process group with one writer and
+        multiple readers.
+
+        This method is designed for scenarios where one process (the writer) sends
+        messages, and all other processes (the readers) receive messages. It sets up
+        the shared memory buffer and socket communication handles accordingly, and
+        broadcasts the handle from the writer to all readers.
+
+        Args:
+            pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
+                group.
+            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
+            max_chunks (int): Maximum number of chunks in the buffer.
+            writer_rank (int, optional): The global rank that will act as the writer.
+                Defaults to 0.
+            external_writer_handle (Handle, optional): Used when there is a handle
+                from an external Message Queue. If provided, use this handle to init
+                the PG writer message queue instead of creating a new one.
+                Defaults to None.
+            blocking (bool, optional): If True, blocks until all processes are ready.
+                Defaults to True.
+
+        Returns:
+            MessageQueue: The MessageQueue instance for the calling process.
+        """
         if isinstance(pg, ProcessGroup):
             group_rank = dist.get_rank(pg)
             group_world_size = dist.get_world_size(pg)
@@ -617,23 +691,26 @@ class MessageQueue:
             group_rank = pg.rank
             group_world_size = pg.world_size
             global_ranks = list(range(pg.world_size))

         from vllm.distributed.parallel_state import in_the_same_node_as

         status = in_the_same_node_as(pg, source_rank=writer_rank)
-        same_node_ranks = [i for i, s in enumerate(status) if s]
-        n_reader = group_world_size - 1
-        n_local_reader = len(same_node_ranks) - 1
-        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
-        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer_io = MessageQueue(
-                n_reader=n_reader,
-                n_local_reader=n_local_reader,
-                local_reader_ranks=local_reader_ranks,
-                max_chunk_bytes=max_chunk_bytes,
-                max_chunks=max_chunks,
-            )
+            if external_writer_handle is not None:
+                buffer_io = MessageQueue.create_from_handle(
+                    external_writer_handle, group_rank
+                )
+            else:
+                same_node_ranks = [i for i, s in enumerate(status) if s]
+                n_reader = group_world_size - 1
+                n_local_reader = len(same_node_ranks) - 1
+                local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+                buffer_io = MessageQueue(
+                    n_reader=n_reader,
+                    n_local_reader=n_local_reader,
+                    local_reader_ranks=local_reader_ranks,
+                    max_chunk_bytes=max_chunk_bytes,
+                    max_chunks=max_chunks,
+                )
             handle = buffer_io.export_handle()
             if isinstance(pg, ProcessGroup):
                 dist.broadcast_object_list(
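A usage sketch for the two factories above (an assumed call pattern, not code from this commit; the enqueue/dequeue method names follow the rest of this module): every rank of the CPU process group calls the factory collectively, after which the designated writer publishes objects that all other ranks consume.

```python
# Assumed setup: `pg` is the gloo/CPU process group shared by all ranks.
# The factory performs a collective broadcast, so every rank must call it.
import torch.distributed as dist

mq = MessageQueue.create_from_process_group(pg, 1 << 22, 6, writer_rank=0)
if dist.get_rank(pg) == 0:
    mq.enqueue({"cmd": "ping"})   # writer side
else:
    msg = mq.dequeue()            # each reader receives the broadcast object
```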
@@ -651,5 +728,6 @@ class MessageQueue:
         else:
             handle = pg.broadcast_obj(None, writer_rank)
             buffer_io = MessageQueue.create_from_handle(handle, group_rank)
-        buffer_io.wait_until_ready()
+        if blocking:
+            buffer_io.wait_until_ready()
         return buffer_io
@@ -385,6 +385,33 @@ class GroupCoordinator:
                 torch.ops._C, "init_shm_manager"
             )

+    def create_mq_broadcaster(
+        self, writer_rank=0, external_writer_handle=None, blocking=True
+    ):
+        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+
+        return MessageQueue.create_from_process_group(
+            self.cpu_group,
+            1 << 22,
+            6,
+            writer_rank=writer_rank,
+            external_writer_handle=external_writer_handle,
+            blocking=blocking,
+        )
+
+    def create_single_reader_mq_broadcasters(
+        self, reader_rank_in_group=0, blocking=False
+    ):
+        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+
+        return MessageQueue.create_from_process_group_single_reader(
+            self.cpu_group,
+            1 << 22,
+            6,
+            reader_rank=self.ranks[reader_rank_in_group],
+            blocking=blocking,
+        )
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -997,6 +1024,7 @@ class GroupCoordinator:


 _WORLD: GroupCoordinator | None = None
+_INNER_DP_WORLD: GroupCoordinator | None = None
 _NODE_COUNT: int | None = None

@@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
     return _WORLD


+def get_inner_dp_world_group() -> GroupCoordinator:
+    assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
+    return _INNER_DP_WORLD
+
+
 def init_world_group(
     ranks: list[int], local_rank: int, backend: str
 ) -> GroupCoordinator:
@@ -1023,12 +1056,13 @@ def init_model_parallel_group(
     backend: str,
     use_message_queue_broadcaster: bool = False,
     group_name: str | None = None,
+    use_device_communicator: bool = True,
 ) -> GroupCoordinator:
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_device_communicator=True,
+        use_device_communicator=use_device_communicator,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
@@ -1143,7 +1177,14 @@ def init_distributed_environment(
        from vllm.config import get_current_vllm_config

        config = get_current_vllm_config()
-        if (
+        if config is not None and config.parallel_config.nnodes > 1:
+            parallel_config = config.parallel_config
+            ip = parallel_config.master_addr
+            rank = parallel_config.data_parallel_rank * world_size + rank
+            world_size = parallel_config.world_size_across_dp
+            port = parallel_config.master_port
+            distributed_init_method = get_distributed_init_method(ip, port)
+        elif (
             config is not None
             and config.parallel_config.data_parallel_size > 1
             and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1164,6 +1205,14 @@ def init_distributed_environment(
         distributed_init_method,
     )
     if not torch.distributed.is_initialized():
+        logger.info(
+            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
+            world_size,
+            rank,
+            local_rank,
+            distributed_init_method,
+            backend,
+        )
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
             "distributed environment"
@@ -1192,16 +1241,36 @@ def init_distributed_environment(
         # local rank not set, this usually happens in single-node
         # setting, where we can use rank as local rank
         local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
-    global _WORLD, _NODE_COUNT
+    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
         _WORLD = init_world_group(ranks, local_rank, backend)
-        _NODE_COUNT = _node_count(_WORLD.cpu_group)
+        if config.parallel_config.nnodes > 1:
+            _NODE_COUNT = config.parallel_config.nnodes
+        else:
+            _NODE_COUNT = _node_count(_WORLD.cpu_group)
         logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size"
         )
+    if config.parallel_config.nnodes_within_dp > 1:
+        if parallel_config.data_parallel_size > 1:
+            world_size_inner_dp = parallel_config.world_size
+            group_ranks = [
+                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
+                for dp_rank in range(parallel_config.data_parallel_size)
+            ]
+            _INNER_DP_WORLD = init_model_parallel_group(
+                group_ranks,
+                get_world_group().local_rank,
+                backend,
+                use_message_queue_broadcaster=True,
+                group_name="inner_dp_world",
+                use_device_communicator=False,
+            )
+        else:
+            _INNER_DP_WORLD = _WORLD


 def initialize_model_parallel(
@@ -384,6 +384,10 @@ class EngineArgs:
     ) = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+    master_addr: str = ParallelConfig.master_addr
+    master_port: int = ParallelConfig.master_port
+    nnodes: int = ParallelConfig.nnodes
+    node_rank: int = ParallelConfig.node_rank
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
     dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
@@ -394,6 +398,7 @@ class EngineArgs:
     data_parallel_address: str | None = None
     data_parallel_rpc_port: int | None = None
     data_parallel_hybrid_lb: bool = False
+    data_parallel_external_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     all2all_backend: str | None = ParallelConfig.all2all_backend
@@ -749,6 +754,10 @@ class EngineArgs:
             "-pp",
             **parallel_kwargs["pipeline_parallel_size"],
         )
+        parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
+        parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
+        parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
+        parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
         parallel_group.add_argument(
             "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
         )
@@ -803,7 +812,14 @@ class EngineArgs:
             help='Backend for data parallel, either "mp" or "ray".',
         )
         parallel_group.add_argument(
-            "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
+            "--data-parallel-hybrid-lb",
+            "-dph",
+            **parallel_kwargs["data_parallel_hybrid_lb"],
+        )
+        parallel_group.add_argument(
+            "--data-parallel-external-lb",
+            "-dpe",
+            **parallel_kwargs["data_parallel_external_lb"],
         )
         parallel_group.add_argument(
             "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
@@ -1428,12 +1444,56 @@ class EngineArgs:
         assert not headless or not self.data_parallel_hybrid_lb, (
             "data_parallel_hybrid_lb is not applicable in headless mode"
         )
-        data_parallel_external_lb = self.data_parallel_rank is not None
+        assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
+            "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
+        )
+        assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
+            "nnodes > 1 is only supported with data_parallel_backend=mp"
+        )
+        inferred_data_parallel_rank = 0
+        if self.nnodes > 1:
+            world_size = (
+                self.data_parallel_size
+                * self.pipeline_parallel_size
+                * self.tensor_parallel_size
+            )
+            world_size_within_dp = (
+                self.pipeline_parallel_size * self.tensor_parallel_size
+            )
+            local_world_size = world_size // self.nnodes
+            assert world_size % self.nnodes == 0, (
+                f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
+            )
+            assert self.node_rank < self.nnodes, (
+                f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
+            )
+            inferred_data_parallel_rank = (
+                self.node_rank * local_world_size
+            ) // world_size_within_dp
+            if self.data_parallel_size > 1 and self.data_parallel_external_lb:
+                self.data_parallel_rank = inferred_data_parallel_rank
+                logger.info(
+                    "Inferred data_parallel_rank %d from node_rank %d for external lb",
+                    self.data_parallel_rank,
+                    self.node_rank,
+                )
+            elif self.data_parallel_size_local is None:
+                # Infer data parallel size local for internal dplb:
+                self.data_parallel_size_local = max(
+                    local_world_size // world_size_within_dp, 1
+                )
+        data_parallel_external_lb = (
+            self.data_parallel_external_lb or self.data_parallel_rank is not None
+        )
         # Local DP rank = 1, use pure-external LB.
         if data_parallel_external_lb:
+            assert self.data_parallel_rank is not None, (
+                "data_parallel_rank or node_rank must be specified if "
+                "data_parallel_external_lb is enabled."
+            )
             assert self.data_parallel_size_local in (1, None), (
-                "data_parallel_size_local must be 1 when data_parallel_rank is set"
+                "data_parallel_size_local must be 1 or None when data_parallel_rank "
+                "is set"
             )
             data_parallel_size_local = 1
             # Use full external lb if we have local_size of 1.
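To make the rank inference above concrete (hypothetical values, not taken from this change): with --nnodes 4, data_parallel_size=2, TP=2, PP=1 and external DP load balancing, the node rank alone determines the DP rank.

```python
# Hypothetical: nnodes=4, data_parallel_size=2, tensor_parallel_size=2,
# pipeline_parallel_size=1, data_parallel_external_lb=True.
world_size = 2 * 1 * 2              # DP * PP * TP = 4
world_size_within_dp = 1 * 2        # PP * TP = 2
local_world_size = world_size // 4  # one worker per node
# node_rank=3 -> (3 * 1) // 2 == 1: this node serves the second DP replica.
inferred_data_parallel_rank = (3 * local_world_size) // world_size_within_dp
```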
@@ -1447,6 +1507,11 @@ class EngineArgs:

         if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
             # Use full external lb if we have local_size of 1.
+            logger.warning(
+                "data_parallel_hybrid_lb is not eligible when "
+                "data_parallel_size_local = 1, autoswitch to "
+                "data_parallel_external_lb."
+            )
             data_parallel_external_lb = True
             self.data_parallel_hybrid_lb = False

@@ -1454,7 +1519,15 @@ class EngineArgs:
             # Disable hybrid LB mode if set for a single node
             self.data_parallel_hybrid_lb = False

-            self.data_parallel_rank = self.data_parallel_start_rank or 0
+            self.data_parallel_rank = (
+                self.data_parallel_start_rank or inferred_data_parallel_rank
+            )
+            if self.nnodes > 1:
+                logger.info(
+                    "Inferred data_parallel_rank %d from node_rank %d",
+                    self.data_parallel_rank,
+                    self.node_rank,
+                )
         else:
             assert not self.data_parallel_hybrid_lb, (
                 "data_parallel_size_local must be set to use data_parallel_hybrid_lb."
@@ -1484,7 +1557,9 @@ class EngineArgs:
                 "data_parallel_backend can only be ray or mp, got %s",
                 self.data_parallel_backend,
             )
-            data_parallel_address = ParallelConfig.data_parallel_master_ip
+            data_parallel_address = (
+                self.master_addr or ParallelConfig.data_parallel_master_ip
+            )
         else:
             data_parallel_address = self.data_parallel_address

@@ -1517,6 +1592,10 @@ class EngineArgs:
             data_parallel_rank=self.data_parallel_rank or 0,
             data_parallel_external_lb=data_parallel_external_lb,
             data_parallel_size_local=data_parallel_size_local,
+            master_addr=self.master_addr,
+            master_port=self.master_port,
+            nnodes=self.nnodes,
+            node_rank=self.node_rank,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
             data_parallel_backend=self.data_parallel_backend,
@@ -24,6 +24,7 @@ from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.engine.core import EngineCoreProc
 from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
 from vllm.v1.executor import Executor
+from vllm.v1.executor.multiproc_executor import MultiprocExecutor
 from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
 from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure

@@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace):
     if local_engine_count <= 0:
         raise ValueError("data_parallel_size_local must be > 0 in headless mode")

-    host = parallel_config.data_parallel_master_ip
-    port = engine_args.data_parallel_rpc_port  # add to config too
-    handshake_address = get_tcp_uri(host, port)
+    shutdown_requested = False

     # Catch SIGTERM and SIGINT to allow graceful shutdown.
     def signal_handler(signum, frame):
+        nonlocal shutdown_requested
         logger.debug("Received %d signal.", signum)
-        raise SystemExit
+        if not shutdown_requested:
+            shutdown_requested = True
+            raise SystemExit

     signal.signal(signal.SIGTERM, signal_handler)
     signal.signal(signal.SIGINT, signal_handler)

+    if parallel_config.node_rank_within_dp > 0:
+        from vllm.version import __version__ as VLLM_VERSION
+
+        # Run headless workers (for multi-node PP/TP).
+        host = parallel_config.master_addr
+        head_node_address = f"{host}:{parallel_config.master_port}"
+        logger.info(
+            "Launching vLLM (v%s) headless multiproc executor, "
+            "with head node address %s for torch.distributed process group.",
+            VLLM_VERSION,
+            head_node_address,
+        )
+
+        executor = MultiprocExecutor(vllm_config, monitor_workers=False)
+        executor.start_worker_monitor(inline=True)
+        return
+
+    host = parallel_config.data_parallel_master_ip
+    port = parallel_config.data_parallel_rpc_port
+    handshake_address = get_tcp_uri(host, port)
+
     logger.info(
         "Launching %d data parallel engine(s) in headless mode, "
         "with head node address %s.",
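Reading the branch above together with the executor changes further down, a secondary node in a two-node TP deployment would do roughly the following (a sketch under assumed config values, not code from this commit):

```python
# Assumes vllm_config was created with nnodes=2, node_rank=1 and the same
# master_addr/master_port as node 0; node 0 runs the engine as usual.
from vllm.v1.executor.multiproc_executor import MultiprocExecutor

executor = MultiprocExecutor(vllm_config, monitor_workers=False)
# Run the worker monitor inline so this process simply blocks until its
# local workers exit, mirroring the headless path above.
executor.start_worker_monitor(inline=True)
```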
@@ -183,15 +183,19 @@ def set_device_control_env_var(
     for engine subprocess.
     """
     world_size = vllm_config.parallel_config.world_size
+    local_world_size = vllm_config.parallel_config.local_world_size
     evar = current_platform.device_control_env_var

-    value = get_device_indices(evar, local_dp_rank, world_size)
+    value = get_device_indices(evar, local_dp_rank, world_size, local_world_size)
     with patch.dict(os.environ, values=((evar, value),)):
         yield


 def get_device_indices(
-    device_control_env_var: str, local_dp_rank: int, world_size: int
+    device_control_env_var: str,
+    local_dp_rank: int,
+    world_size: int,
+    local_world_size: int | None = None,
 ):
     """
     Returns a comma-separated string of device indices for the specified
@@ -200,10 +204,15 @@ def get_device_indices(
     For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
     this will select devices 2 and 3 for local_dp_rank=1.
     """
+    if local_world_size is None:
+        local_world_size = world_size
     try:
         value = ",".join(
             str(current_platform.device_id_to_physical_device_id(i))
-            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)
+            for i in range(
+                local_dp_rank * world_size,
+                local_dp_rank * world_size + local_world_size,
+            )
         )
     except IndexError as e:
         raise Exception(
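A worked example of the widened device selection above (hypothetical values; assumes enough visible GPUs and the default identity mapping from logical to physical device ids): with world_size=4, local_world_size=2 and local_dp_rank=1, the selected range is range(4, 6).

```python
# Hypothetical call, assuming 8 visible GPUs and no prior device remapping:
value = get_device_indices(
    "CUDA_VISIBLE_DEVICES", local_dp_rank=1, world_size=4, local_world_size=2
)
# value == "4,5": this node hosts the first two workers of the second DP rank;
# its peer node would be assigned "6,7".
```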
@@ -10,7 +10,7 @@ import time
 import traceback
 import weakref
 from collections import deque
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from concurrent.futures import Future, InvalidStateError
 from contextlib import suppress
 from dataclasses import dataclass
@@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
     get_dcp_group,
     get_dp_group,
     get_ep_group,
+    get_inner_dp_world_group,
     get_pp_group,
     get_tp_group,
 )
@@ -90,6 +91,10 @@ class FutureWrapper(Future):
 class MultiprocExecutor(Executor):
     supports_pp: bool = True

+    def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
+        self.monitor_workers = monitor_workers
+        super().__init__(vllm_config)
+
     def _init_executor(self) -> None:
         # Call self.shutdown at exit to clean up
         # and ensure workers will be terminated.
@@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
         self.failure_callback: FailureCallback | None = None

         self.world_size = self.parallel_config.world_size
+        assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
+            f"global world_size ({self.parallel_config.world_size}) must be "
+            f"divisible by nnodes_within_dp "
+            f"({self.parallel_config.nnodes_within_dp}). "
+        )
+        self.local_world_size = self.parallel_config.local_world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
         pp_parallel_size = self.parallel_config.pipeline_parallel_size
         assert self.world_size == tensor_parallel_size * pp_parallel_size, (
@@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
         distributed_init_method = get_distributed_init_method(
             get_loopback_ip(), get_open_port()
         )
+        self.rpc_broadcast_mq: MessageQueue | None = None
+        scheduler_output_handle: Handle | None = None
         # Initialize worker and set up message queues for SchedulerOutputs
         # and ModelRunnerOutputs
-        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
-        self.rpc_broadcast_mq = MessageQueue(
-            self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
-        )
-        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
+        if self.parallel_config.node_rank_within_dp == 0:
+            # For leader node within each dp rank,
+            # each dp will have its own leader multiproc executor.
+            max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
+            self.rpc_broadcast_mq = MessageQueue(
+                self.world_size,
+                self.local_world_size,
+                max_chunk_bytes=max_chunk_bytes,
+                connect_ip=self.parallel_config.master_addr,
+            )
+            scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
         # Create workers
         context = get_mp_context()
         shared_worker_lock = context.Lock()
         unready_workers: list[UnreadyWorkerProcHandle] = []
         success = False
         try:
-            for rank in range(self.world_size):
+            global_start_rank = (
+                self.local_world_size * self.parallel_config.node_rank_within_dp
+            )
+            for local_rank in range(self.local_world_size):
+                global_rank = global_start_rank + local_rank
                 unready_workers.append(
                     WorkerProc.make_worker_process(
                         vllm_config=self.vllm_config,
-                        local_rank=rank,
-                        rank=rank,
+                        local_rank=local_rank,
+                        rank=global_rank,
                         distributed_init_method=distributed_init_method,
                         input_shm_handle=scheduler_output_handle,
                         shared_worker_lock=shared_worker_lock,
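Each node now spawns only its own local workers; a worker's global rank is its local rank plus the node's offset inside the data-parallel group. Illustrative rank math only, with hypothetical values (not vLLM internals):

local_world_size = 4      # GPUs driven by this node
node_rank_within_dp = 1   # second node in its DP group

global_start_rank = local_world_size * node_rank_within_dp
global_ranks = [global_start_rank + r for r in range(local_world_size)]
print(global_ranks)       # [4, 5, 6, 7] -> the ranks this node's workers receive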
@@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
 
             # Workers must be created before wait_for_ready to avoid
             # deadlock, since worker.init_device() does a device sync.
 
+            # Wait for all local workers to be ready.
             self.workers = WorkerProc.wait_for_ready(unready_workers)
 
+            # Start background thread to monitor worker health if not in headless mode.
+            if self.monitor_workers:
+                self.start_worker_monitor()
+
+            self.response_mqs = []
+            # Only leader node have remote response mqs
+            if self.parallel_config.node_rank_within_dp == 0:
+                for rank in range(self.world_size):
+                    if rank < self.local_world_size:
+                        local_message_queue = self.workers[rank].worker_response_mq
+                        assert local_message_queue is not None
+                        self.response_mqs.append(local_message_queue)
+                    else:
+                        remote_message_queue = self.workers[0].peer_worker_response_mqs[
+                            rank
+                        ]
+                        assert remote_message_queue is not None
+                        self.response_mqs.append(remote_message_queue)
+
             # Ensure message queues are ready. Will deadlock if re-ordered
             # Must be kept consistent with the WorkerProc.
-            self.rpc_broadcast_mq.wait_until_ready()
-            for w in self.workers:
-                w.worker_response_mq.wait_until_ready()
-
-            self.start_worker_monitor()
+            # Wait for all input mqs to be ready.
+            if self.rpc_broadcast_mq is not None:
+                self.rpc_broadcast_mq.wait_until_ready()
+            # Wait for all remote response mqs to be ready.
+            for response_mq in self.response_mqs:
+                response_mq.wait_until_ready()
             success = True
         finally:
             if not success:
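On the leader node, response queues for ranks below local_world_size come straight from the local worker handles, while higher ranks are read through the peer queues exposed by the rank-0 worker process. A sketch of that routing with plain strings standing in for MessageQueue objects (hypothetical 2-node x 2-GPU setup):

local_world_size = 2
world_size = 4
workers = [{"worker_response_mq": f"local-mq-{r}"} for r in range(local_world_size)]
peer_mqs = {2: "remote-mq-2", 3: "remote-mq-3"}  # held by the rank-0 worker's handle

response_mqs = []
for rank in range(world_size):
    if rank < local_world_size:
        response_mqs.append(workers[rank]["worker_response_mq"])
    else:
        response_mqs.append(peer_mqs[rank])
print(response_mqs)  # ['local-mq-0', 'local-mq-1', 'remote-mq-2', 'remote-mq-3']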
@@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
 
         self.output_rank = self._get_output_rank()
 
-    def start_worker_monitor(self):
+    def start_worker_monitor(self, inline=False) -> None:
         workers = self.workers
         self_ref = weakref.ref(self)
 
@@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
                 _self.failure_callback = None
                 callback()
 
-        Thread(
-            target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
-        ).start()
+        if not inline:
+            Thread(
+                target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
+            ).start()
+            return
+
+        monitor_workers()
 
     def register_failure_callback(self, callback: FailureCallback):
         if self.is_failed:
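With the new inline flag, a caller can block in the monitor loop itself instead of spawning a daemon thread. A generic sketch of the same pattern in isolation (not the vLLM implementation):

from threading import Thread

def start_monitor(monitor, inline: bool = False) -> None:
    if not inline:
        # Default: fire-and-forget daemon thread.
        Thread(target=monitor, daemon=True, name="WorkerMonitor").start()
        return
    # Inline mode: the caller blocks until the monitor returns, which is
    # what a headless follower process wants as its main loop.
    monitor()

start_monitor(lambda: print("monitoring..."), inline=True)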
@@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
     ) -> Any | list[Any] | Future[Any | list[Any]]:
         """Returns single result if unique_reply_rank and/or kv_output_aggregator
         is provided, otherwise list."""
+        assert self.rpc_broadcast_mq is not None, (
+            "collective_rpc should not be called on follower node"
+        )
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
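Only the leader node within a DP group owns the broadcast queue, so an RPC issued anywhere else should fail fast rather than hang. A trivial illustration of that invariant (stand-in values, not vLLM code):

rpc_broadcast_mq = None  # what a follower node holds
try:
    assert rpc_broadcast_mq is not None, (
        "collective_rpc should not be called on follower node"
    )
except AssertionError as err:
    print(err)  # fail fast instead of waiting on a queue that does not exist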
@@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
         send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
         self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
 
-        workers = (
-            (self.workers[output_rank],) if output_rank is not None else self.workers
-        )
+        response_mqs: Sequence[MessageQueue] = self.response_mqs
+        if output_rank is not None:
+            response_mqs = (response_mqs[output_rank],)
 
         shutdown_event = self.shutdown_event
 
         def get_response():
             responses = []
-            for w in workers:
+            for mq in response_mqs:
                 dequeue_timeout = (
                     None if deadline is None else (deadline - time.monotonic())
                 )
                 try:
-                    status, result = w.worker_response_mq.dequeue(
+                    status, result = mq.dequeue(
                         timeout=dequeue_timeout, cancel=shutdown_event
                     )
                 except TimeoutError as e:
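Responses are now gathered from the response_mqs list (local and remote alike) under one overall deadline, recomputed into a per-dequeue timeout on each iteration. A sketch of that deadline bookkeeping with queue.Queue standing in for the message queues (illustrative only):

import queue
import time

def gather(mqs, timeout_s=None):
    """Collect one item per queue, sharing a single overall deadline."""
    deadline = None if timeout_s is None else time.monotonic() + timeout_s
    results = []
    for mq in mqs:
        remaining = None if deadline is None else max(deadline - time.monotonic(), 0)
        results.append(mq.get(timeout=remaining))  # raises queue.Empty when late
    return results

mqs = [queue.Queue() for _ in range(2)]
for i, mq in enumerate(mqs):
    mq.put(f"result-{i}")
print(gather(mqs, timeout_s=1.0))  # ['result-0', 'result-1']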
@@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
 class WorkerProcHandle:
     proc: BaseProcess
     rank: int
-    worker_response_mq: MessageQueue  # The worker process writes to this MQ
+    # The worker process writes to this MQ in single-node mode
+    worker_response_mq: MessageQueue | None
+    # This is only non empty on driver node,
+    # the peer worker process i writes to MQ
+    # `peer_worker_response_mqs[i]`
+    peer_worker_response_mqs: list[MessageQueue | None]
     death_writer: Connection | None = None
 
     @classmethod
     def from_unready_handle(
-        cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
+        cls,
+        unready_handle: UnreadyWorkerProcHandle,
+        worker_response_mq: MessageQueue | None,
+        peer_worker_response_mqs: list[MessageQueue | None],
     ) -> "WorkerProcHandle":
         return cls(
             proc=unready_handle.proc,
             rank=unready_handle.rank,
             worker_response_mq=worker_response_mq,
+            peer_worker_response_mqs=peer_worker_response_mqs,
             death_writer=unready_handle.death_writer,
         )
 
@@ -411,6 +470,38 @@ class WorkerProc:
 
     READY_STR = "READY"
 
+    def _init_message_queues(
+        self, input_shm_handle: Handle, vllm_config: VllmConfig
+    ) -> None:
+        if vllm_config.parallel_config.nnodes_within_dp == 1:
+            # Initialize MessageQueue for receiving SchedulerOutput
+            self.rpc_broadcast_mq = MessageQueue.create_from_handle(
+                input_shm_handle, self.worker.rank
+            )
+
+            # Initializes a message queue for sending the model output
+            self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
+            self.peer_response_handles = []
+        else:
+            # Initialize remote MessageQueue for receiving SchedulerOutput across nodes
+            self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
+                external_writer_handle=input_shm_handle,
+                # Since there is external_writer_handle from executor proc,
+                # where the ready signal from actual writer is sent out of the
+                # create_mq_broadcaster method and after this setup, we make it
+                # non blocking. The handshake will be triggered when
+                # worker.rpc_broadcast_mq.wait_until_ready() is called
+                blocking=False,
+            )
+            # Initializes remote message queue for sending the model output to the
+            # driver worker, exposing peer_response_handles for driver worker
+            # that include handles for all ranks
+            self.worker_response_mq, self.peer_response_handles = (
+                get_inner_dp_world_group().create_single_reader_mq_broadcasters(
+                    reader_rank_in_group=0
+                )
+            )
+
     def __init__(
         self,
         vllm_config: VllmConfig,
@@ -421,13 +512,15 @@ class WorkerProc:
         shared_worker_lock: LockType,
     ):
         self.rank = rank
-        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
+        wrapper = WorkerWrapperBase(
+            vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
+        )
         # TODO: move `init_worker` to executor level as a collective rpc call
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
         is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
-        all_kwargs[rank] = {
+        all_kwargs[local_rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
             "rank": rank,
@@ -438,14 +531,6 @@ class WorkerProc:
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
 
-        # Initialize MessageQueue for receiving SchedulerOutput
-        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
-            input_shm_handle, self.worker.rank
-        )
-
-        # Initializes a message queue for sending the model output
-        self.worker_response_mq = MessageQueue(1, 1)
-
         scheduler_config = vllm_config.scheduler_config
         self.use_async_scheduling = scheduler_config.async_scheduling
         if self.use_async_scheduling:
@@ -466,6 +551,7 @@ class WorkerProc:
         )
 
         # Load model
+        self._init_message_queues(input_shm_handle, vllm_config)
         self.worker.load_model()
 
         # Enable environment variable cache (e.g. assume no more
@@ -512,6 +598,27 @@ class WorkerProc:
         # death_reader in child will get EOFError
         return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
 
+    @staticmethod
+    def wait_for_response_handle_ready(
+        handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
+    ) -> WorkerProcHandle:
+        response_handle = handles["handle"]
+        worker_response_mq: MessageQueue | None = None
+        if len(response_handle.local_reader_ranks) > 0:
+            worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
+        peer_response_handles = handles["peer_response_handles"]
+        peer_worker_response_mqs = [
+            MessageQueue.create_from_handle(handle, -1)
+            if handle.remote_subscribe_addr is not None
+            else None
+            for handle in peer_response_handles
+        ]
+        return WorkerProcHandle.from_unready_handle(
+            proc_handle,
+            worker_response_mq,
+            peer_worker_response_mqs=peer_worker_response_mqs,
+        )
+
     @staticmethod
     def wait_for_ready(
         unready_proc_handles: list[UnreadyWorkerProcHandle],
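The driver only materializes a reader for peers that expose a remote subscribe address; purely local ranks stay None because their queues are attached directly through the local handle. A stand-in sketch of that filtering (strings instead of MessageQueue objects, address is made up):

from types import SimpleNamespace

peer_response_handles = [
    SimpleNamespace(remote_subscribe_addr=None),                   # local rank
    SimpleNamespace(remote_subscribe_addr="tcp://10.0.0.2:5555"),  # remote rank
]
peer_worker_response_mqs = [
    f"reader({h.remote_subscribe_addr})" if h.remote_subscribe_addr is not None else None
    for h in peer_response_handles
]
print(peer_worker_response_mqs)  # [None, 'reader(tcp://10.0.0.2:5555)']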
@@ -537,16 +644,10 @@ class WorkerProc:
                     if response["status"] != "READY":
                         raise e
 
-                    # Extract the message queue handle.
-                    worker_response_mq = MessageQueue.create_from_handle(
-                        response["handle"], 0
+                    idx = unready_proc_handle.rank % len(ready_proc_handles)
+                    ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
+                        response, unready_proc_handle
                     )
-                    ready_proc_handles[unready_proc_handle.rank] = (
-                        WorkerProcHandle.from_unready_handle(
-                            unready_proc_handle, worker_response_mq
-                        )
-                    )
 
                 except EOFError:
                     e.__suppress_context__ = True
                     raise e from None
@@ -618,12 +719,14 @@ class WorkerProc:
                 {
                     "status": WorkerProc.READY_STR,
                     "handle": worker.worker_response_mq.export_handle(),
+                    "peer_response_handles": worker.peer_response_handles,
                 }
             )
 
             # Ensure message queues are ready. Will deadlock if re-ordered.
             # Must be kept consistent with the Executor
-            worker.rpc_broadcast_mq.wait_until_ready()
+            if worker.rpc_broadcast_mq is not None:
+                worker.rpc_broadcast_mq.wait_until_ready()
             worker.worker_response_mq.wait_until_ready()
             ready_writer.close()
             ready_writer = None
@@ -189,6 +189,7 @@ class Worker(WorkerBase):
             and self.parallel_config.distributed_executor_backend
             not in ["ray", "external_launcher"]
             and self.vllm_config.parallel_config.data_parallel_backend != "ray"
+            and self.vllm_config.parallel_config.nnodes_within_dp == 1
         ):
             # Use local DP rank if available, otherwise use global DP rank.
             dp_local_rank = self.parallel_config.data_parallel_rank_local
@@ -205,7 +206,14 @@ class Worker(WorkerBase):
             assert self.local_rank < torch.cuda.device_count(), (
                 f"DP adjusted local rank {self.local_rank} is out of bounds. "
             )
+        visible_device_count = (
+            torch.cuda.device_count() if torch.cuda.is_available() else 0
+        )
+        assert self.parallel_config.local_world_size <= visible_device_count, (
+            f"local_world_size ({self.parallel_config.local_world_size}) must be "
+            f"less than or equal to the number of visible devices "
+            f"({visible_device_count})."
+        )
         self.device = torch.device(f"cuda:{self.local_rank}")
         current_platform.set_device(self.device)
 
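The added guard keeps a node from planning more local workers than it has visible GPUs. A small sketch of the same check (requires torch; the worker count is a hypothetical value):

import torch

local_world_size = 4  # hypothetical number of workers planned for this node
visible = torch.cuda.device_count() if torch.cuda.is_available() else 0
if local_world_size > visible:
    # In the real worker this is a hard assertion; printed here for illustration.
    print(f"would fail: local_world_size ({local_world_size}) > visible devices ({visible})")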
@@ -180,6 +180,7 @@ class WorkerWrapperBase:
         self,
         vllm_config: VllmConfig,
         rpc_rank: int = 0,
+        global_rank: int | None = None,
     ) -> None:
         """
         Initialize the worker wrapper with the given vllm_config and rpc_rank.
@@ -192,6 +193,7 @@ class WorkerWrapperBase:
         group.
         """
         self.rpc_rank = rpc_rank
+        self.global_rank = self.rpc_rank if global_rank is None else global_rank
         self.worker: WorkerBase | None = None
 
         # do not store this `vllm_config`, `init_worker` will set the final
@@ -312,7 +314,7 @@ class WorkerWrapperBase:
         assert self.worker is not None
 
     def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
-        kv_cache_config = kv_cache_configs[self.rpc_rank]
+        kv_cache_config = kv_cache_configs[self.global_rank]
         with set_current_vllm_config(self.vllm_config):
             self.worker.initialize_from_config(kv_cache_config)  # type: ignore
 
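Why the index changed: kv_cache_configs holds one entry per global rank, so a multi-node worker whose rpc_rank is only node-local would otherwise read another rank's slot. A tiny illustration with hypothetical values for a 2-node x 2-GPU setup:

kv_cache_configs = ["cfg-rank0", "cfg-rank1", "cfg-rank2", "cfg-rank3"]  # one per global rank
rpc_rank = 1      # this worker's position among the processes on its node
global_rank = 3   # node offset (2) + local rank (1) on the second node
assert kv_cache_configs[global_rank] == "cfg-rank3"   # correct entry
assert kv_cache_configs[rpc_rank] == "cfg-rank1"      # what the old indexing would pick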