[Feature] Support Pipeline Parallism in torchrun SPMD offline inference for V1 (#17827)

Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang 2025-05-15 22:28:27 -07:00 committed by GitHub
parent 6b31c84aff
commit 3d2779c29a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 55 additions and 27 deletions

View File

@ -148,6 +148,8 @@ steps:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py

View File

@ -8,6 +8,8 @@ the argument 2 should match the `tensor_parallel_size` below.
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""
import torch.distributed as dist
from vllm import LLM, SamplingParams
# Create prompts, the same across all ranks
@ -27,23 +29,26 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model="facebook/opt-125m",
model="meta-llama/Llama-3.1-8B",
tensor_parallel_size=2,
pipeline_parallel_size=2,
distributed_executor_backend="external_launcher",
seed=0,
max_model_len=32768,
seed=1,
)
outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}")
if dist.get_rank() == 0:
print("-" * 50)
"""
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# unit test for `examples/offline_inference/torchrun_example.py`
import os
import random
import torch.distributed as dist
@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="facebook/opt-125m",
tensor_parallel_size=2,
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),

View File

@ -1695,7 +1695,6 @@ class ParallelConfig:
"""Port of the data parallel master."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor

View File

@ -265,7 +265,8 @@ class CustomAllreduce:
def close(self):
if not self.disabled and self._ptr:
ops.dispose(self._ptr)
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
@ -298,4 +299,5 @@ class CustomAllreduce:
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
ops.free_shared_buffer(pointers[rank])
if ops is not None:
ops.free_shared_buffer(pointers[rank])

View File

@ -1383,9 +1383,10 @@ class EngineArgs:
return False
if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend not in ["ray", "mp"]):
and self.distributed_executor_backend
not in ("ray", "mp", "external_launcher")):
name = "Pipeline Parallelism without Ray distributed executor " \
"or multiprocessing executor"
"or multiprocessing executor or external launcher"
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False

View File

@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
("ExecutorWithExternalLauncher does not "
"support pipeline parallelism.")
assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
("ExecutorWithExternalLauncher needs deterministic "
"execution, so it"

View File

@ -22,7 +22,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, graph_capture, prepare_communication_buffer_for_model)
get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@ -1162,13 +1163,32 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
return hidden_states
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if not broadcast_pp_output:
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
@ -1186,6 +1206,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,

View File

@ -275,13 +275,13 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if not get_pp_group().is_last_rank:
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None