mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:35:00 +08:00
[Feature] Support Pipeline Parallism in torchrun SPMD offline inference for V1 (#17827)
Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
parent
6b31c84aff
commit
3d2779c29a
@ -148,6 +148,8 @@ steps:
|
||||
# test with tp=2 and external_dp=2
|
||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
|
||||
@ -8,6 +8,8 @@ the argument 2 should match the `tensor_parallel_size` below.
|
||||
see `tests/distributed/test_torchrun_example.py` for the unit test.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Create prompts, the same across all ranks
|
||||
@ -27,23 +29,26 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# all ranks have the same random seed, so that sampling can be
|
||||
# deterministic across ranks.
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
model="meta-llama/Llama-3.1-8B",
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=2,
|
||||
distributed_executor_backend="external_launcher",
|
||||
seed=0,
|
||||
max_model_len=32768,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# all ranks will have the same outputs
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}")
|
||||
if dist.get_rank() == 0:
|
||||
print("-" * 50)
|
||||
"""
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}\n")
|
||||
print("-" * 50)
|
||||
"""
|
||||
Further tips:
|
||||
|
||||
1. to communicate control messages across all ranks, use the cpu group,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# unit test for `examples/offline_inference/torchrun_example.py`
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch.distributed as dist
|
||||
@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# to test if all ranks agree on the same kv cache configuration.
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
|
||||
distributed_executor_backend="external_launcher",
|
||||
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||
swap_space=random.randint(1, 4),
|
||||
|
||||
@ -1695,7 +1695,6 @@ class ParallelConfig:
|
||||
"""Port of the data parallel master."""
|
||||
enable_expert_parallel: bool = False
|
||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
"""Maximum number of parallel loading workers when loading model
|
||||
sequentially in multiple batches. To avoid RAM OOM when using tensor
|
||||
|
||||
@ -265,7 +265,8 @@ class CustomAllreduce:
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
ops.dispose(self._ptr)
|
||||
if ops is not None:
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
|
||||
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
|
||||
@ -298,4 +299,5 @@ class CustomAllreduce:
|
||||
rank: Optional[int] = 0) -> None:
|
||||
if rank is None:
|
||||
rank = dist.get_rank(group=group)
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
if ops is not None:
|
||||
ops.free_shared_buffer(pointers[rank])
|
||||
|
||||
@ -1383,9 +1383,10 @@ class EngineArgs:
|
||||
return False
|
||||
|
||||
if (self.pipeline_parallel_size > 1
|
||||
and self.distributed_executor_backend not in ["ray", "mp"]):
|
||||
and self.distributed_executor_backend
|
||||
not in ("ray", "mp", "external_launcher")):
|
||||
name = "Pipeline Parallelism without Ray distributed executor " \
|
||||
"or multiprocessing executor"
|
||||
"or multiprocessing executor or external launcher"
|
||||
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
|
||||
@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
|
||||
def _init_executor(self) -> None:
|
||||
"""Initialize the worker and load the model.
|
||||
"""
|
||||
assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
|
||||
("ExecutorWithExternalLauncher does not "
|
||||
"support pipeline parallelism.")
|
||||
assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
|
||||
("ExecutorWithExternalLauncher needs deterministic "
|
||||
"execution, so it"
|
||||
|
||||
@ -22,7 +22,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, graph_capture, prepare_communication_buffer_for_model)
|
||||
get_pp_group, get_tp_group, graph_capture,
|
||||
prepare_communication_buffer_for_model)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
@ -1162,13 +1163,32 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
# TODO: Support overlapping mirco-batches
|
||||
# https://github.com/vllm-project/vllm/issues/18019
|
||||
broadcast_pp_output = \
|
||||
self.parallel_config.distributed_executor_backend \
|
||||
== "external_launcher" and len(get_pp_group().ranks) > 0
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
return hidden_states
|
||||
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if not broadcast_pp_output:
|
||||
return hidden_states
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(hidden_states.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
logits = None
|
||||
else:
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
model_output_broadcast_data = {
|
||||
"logits": logits.contiguous(),
|
||||
} if logits is not None else {}
|
||||
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
|
||||
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
|
||||
assert model_output_broadcast_data is not None
|
||||
logits = model_output_broadcast_data["logits"]
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
@ -1186,6 +1206,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
assert logits is not None
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.sampler(
|
||||
logits=bonus_logits,
|
||||
|
||||
@ -275,13 +275,13 @@ class Worker(WorkerBase):
|
||||
|
||||
output = self.model_runner.execute_model(scheduler_output,
|
||||
intermediate_tensors)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
if parallel_config.distributed_executor_backend != "external_launcher" \
|
||||
and not get_pp_group().is_last_rank:
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return None
|
||||
|
||||
assert isinstance(output, ModelRunnerOutput)
|
||||
return output if self.is_driver_worker else None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user