[Feature] Support Pipeline Parallelism in torchrun SPMD offline inference for V1 (#17827)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Lucia Fang 2025-05-15 22:28:27 -07:00 committed by GitHub
parent 6b31c84aff
commit 3d2779c29a
9 changed files with 55 additions and 27 deletions


@@ -148,6 +148,8 @@ steps:
     # test with tp=2 and external_dp=2
     - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
     - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
+    # test with tp=2 and pp=2
+    - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
     # test with internal dp
     - python3 ../examples/offline_inference/data_parallel.py
     - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py


@@ -8,6 +8,8 @@ the argument 2 should match the `tensor_parallel_size` below.
 see `tests/distributed/test_torchrun_example.py` for the unit test.
 """

+import torch.distributed as dist
+
 from vllm import LLM, SamplingParams

 # Create prompts, the same across all ranks
@@ -27,23 +29,26 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 # all ranks have the same random seed, so that sampling can be
 # deterministic across ranks.
 llm = LLM(
-    model="facebook/opt-125m",
+    model="meta-llama/Llama-3.1-8B",
     tensor_parallel_size=2,
+    pipeline_parallel_size=2,
     distributed_executor_backend="external_launcher",
-    seed=0,
+    max_model_len=32768,
+    seed=1,
 )

 outputs = llm.generate(prompts, sampling_params)

 # all ranks will have the same outputs
-print("-" * 50)
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}\n"
-          f"Generated text: {generated_text!r}")
-    print("-" * 50)
+if dist.get_rank() == 0:
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\n"
+              f"Generated text: {generated_text!r}\n")
+        print("-" * 50)

 """
 Further tips:
 1. to communicate control messages across all ranks, use the cpu group,
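
A quick sanity check for the updated example: with distributed_executor_backend="external_launcher", every torchrun rank runs the same script and is mapped onto the tp x pp grid, so the torchrun world size has to equal tensor_parallel_size * pipeline_parallel_size. A minimal sketch of that check (illustrative only, not part of the diff):

import os

tp_size = 2
pp_size = 2
# torchrun exports WORLD_SIZE; external_launcher expects one process per
# tp x pp slot, i.e. --nproc-per-node=4 for tp=2, pp=2.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
assert world_size == tp_size * pp_size, (
    f"launch with: torchrun --nproc-per-node={tp_size * pp_size} torchrun_example.py")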


@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # unit test for `examples/offline_inference/torchrun_example.py`
+import os
 import random

 import torch.distributed as dist
@@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 # to test if all ranks agree on the same kv cache configuration.
 llm = LLM(model="facebook/opt-125m",
           tensor_parallel_size=2,
+          pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
           distributed_executor_backend="external_launcher",
           gpu_memory_utilization=random.uniform(0.7, 0.9),
           swap_space=random.randint(1, 4),
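
The new CI entry (PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py) drives this test through the PP_SIZE environment variable, which defaults to 1 so the existing tp-only invocations keep passing. A stripped-down sketch of the same pattern (names mirror the test; the snippet itself is illustrative):

import os

tp_size = 2                               # hard-coded in the test
pp_size = int(os.getenv("PP_SIZE", "1"))  # 1 unless the CI entry sets PP_SIZE=2
print(f"tensor_parallel_size={tp_size}, pipeline_parallel_size={pp_size}")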


@@ -1695,7 +1695,6 @@ class ParallelConfig:
     """Port of the data parallel master."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     max_parallel_loading_workers: Optional[int] = None
     """Maximum number of parallel loading workers when loading model
     sequentially in multiple batches. To avoid RAM OOM when using tensor


@@ -265,6 +265,7 @@ class CustomAllreduce:
     def close(self):
         if not self.disabled and self._ptr:
-            ops.dispose(self._ptr)
+            if ops is not None:
+                ops.dispose(self._ptr)
             self._ptr = 0
             self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
@@ -298,4 +299,5 @@ class CustomAllreduce:
                            rank: Optional[int] = 0) -> None:
         if rank is None:
             rank = dist.get_rank(group=group)
-        ops.free_shared_buffer(pointers[rank])
+        if ops is not None:
+            ops.free_shared_buffer(pointers[rank])
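
Both guards above make the teardown path tolerant of `ops` being None, which can happen when close()/free_shared_buffer run very late (for example during interpreter shutdown, when module-level globals may already be cleared). A self-contained sketch of the same defensive pattern; `_fake_ops` here is a stand-in, not vLLM's real extension module:

class _FakeOps:
    def dispose(self, ptr: int) -> None:
        print(f"disposed {ptr:#x}")

_fake_ops = _FakeOps()  # stands in for the compiled ops module

class Handle:
    def __init__(self, ptr: int) -> None:
        self._ptr = ptr

    def close(self) -> None:
        # Re-check the module-level object before calling into it; it can be
        # None if close() runs during interpreter shutdown.
        if self._ptr and _fake_ops is not None:
            _fake_ops.dispose(self._ptr)
            self._ptr = 0

Handle(0x1234).close()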


@@ -1383,9 +1383,10 @@ class EngineArgs:
             return False

         if (self.pipeline_parallel_size > 1
-                and self.distributed_executor_backend not in ["ray", "mp"]):
+                and self.distributed_executor_backend
+                not in ("ray", "mp", "external_launcher")):
             name = "Pipeline Parallelism without Ray distributed executor " \
-                "or multiprocessing executor"
+                   "or multiprocessing executor or external launcher"
             _raise_or_fallback(feature_name=name, recommend_to_remove=False)
             return False
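
The V1 pre-check above now accepts external_launcher as a pipeline-parallel-capable backend alongside Ray and multiprocessing. A condensed sketch of the rule it encodes (the constant and function names are hypothetical, not vLLM API):

SUPPORTED_PP_BACKENDS = ("ray", "mp", "external_launcher")

def pp_supported_on_v1(pipeline_parallel_size: int, backend: str | None) -> bool:
    # PP with any other backend falls back to V0 or raises, depending on the
    # _raise_or_fallback policy above.
    return pipeline_parallel_size == 1 or backend in SUPPORTED_PP_BACKENDS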


@@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
         """
-        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
-            ("ExecutorWithExternalLauncher does not "
-             "support pipeline parallelism.")
         assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
             ("ExecutorWithExternalLauncher needs deterministic "
              "execution, so it"


@@ -22,7 +22,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
-    get_pp_group, graph_capture, prepare_communication_buffer_for_model)
+    get_pp_group, get_tp_group, graph_capture,
+    prepare_communication_buffer_for_model)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1162,13 +1163,32 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             hidden_states, aux_hidden_states = model_output
         else:
             hidden_states = model_output

+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        broadcast_pp_output = \
+            self.parallel_config.distributed_executor_backend \
+            == "external_launcher" and len(get_pp_group().ranks) > 0
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
-            return hidden_states
-        sample_hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+            if not broadcast_pp_output:
+                return hidden_states
+            assert isinstance(hidden_states, IntermediateTensors)
+            get_pp_group().send_tensor_dict(hidden_states.tensors,
+                                            all_gather_group=get_tp_group())
+            logits = None
+        else:
+            sample_hidden_states = hidden_states[logits_indices]
+            logits = self.model.compute_logits(sample_hidden_states, None)
+        if broadcast_pp_output:
+            model_output_broadcast_data = {
+                "logits": logits.contiguous(),
+            } if logits is not None else {}
+            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
+            assert model_output_broadcast_data is not None
+            logits = model_output_broadcast_data["logits"]

         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
@@ -1186,6 +1206,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # creates a new tensor with separate storage from the original
             # logits tensor. This means any in-place operations on bonus_logits
             # won't affect the original logits tensor.
+            assert logits is not None
             bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
             sampler_output = self.sampler(
                 logits=bonus_logits,
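
This is the core of the change: under external_launcher every rank must return the same sampled output, so instead of only forwarding IntermediateTensors, the last pipeline stage broadcasts its logits back to the earlier stages via get_pp_group().broadcast_tensor_dict. The sketch below shows the same idea with plain torch.distributed on the default process group (the helper name and usage are illustrative, assuming the group is already initialized, e.g. by torchrun):

import torch.distributed as dist

def broadcast_from_last_rank(obj):
    # The highest rank owns the object (e.g. the logits tensor); after the
    # broadcast every rank holds an identical copy.
    last = dist.get_world_size() - 1
    buf = [obj if dist.get_rank() == last else None]
    dist.broadcast_object_list(buf, src=last)
    return buf[0]

# Usage sketch: only the last PP stage computes logits, then all stages share them.
# logits = broadcast_from_last_rank(logits if is_last_pp_rank else None)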


@@ -275,13 +275,13 @@ class Worker(WorkerBase):
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)

-        if not get_pp_group().is_last_rank:
+        parallel_config = self.vllm_config.parallel_config
+        if parallel_config.distributed_executor_backend != "external_launcher" \
+            and not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
             return None

         assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None
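
With the model runner now handling the cross-stage sync for external_launcher, the worker skips its own IntermediateTensors forwarding in that case and returns a full ModelRunnerOutput on every rank. A compact sketch of that routing decision (the helper name is hypothetical):

def worker_forwards_intermediates(backend: str | None, is_last_pp_rank: bool) -> bool:
    # external_launcher: the model runner already sent/broadcast the PP output,
    # so nothing is forwarded here. Other backends: non-last stages still ship
    # IntermediateTensors to the next stage from the worker.
    return backend != "external_launcher" and not is_last_pp_rank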