diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c4459741712d..461fb6d30c45 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -148,6 +148,8 @@ steps:
   # test with tp=2 and external_dp=2
   - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
   - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
+  # test with tp=2 and pp=2
+  - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
   # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
diff --git a/examples/offline_inference/torchrun_example.py b/examples/offline_inference/torchrun_example.py
index c6d9e6b47e21..bb61a0a29e32 100644
--- a/examples/offline_inference/torchrun_example.py
+++ b/examples/offline_inference/torchrun_example.py
@@ -8,6 +8,8 @@ the argument 2 should match the `tensor_parallel_size` below.
 see `tests/distributed/test_torchrun_example.py` for the unit test.
 """

+import torch.distributed as dist
+
 from vllm import LLM, SamplingParams

 # Create prompts, the same across all ranks
@@ -27,23 +29,26 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 # all ranks have the same random seed, so that sampling can be
 # deterministic across ranks.
 llm = LLM(
-    model="facebook/opt-125m",
+    model="meta-llama/Llama-3.1-8B",
     tensor_parallel_size=2,
+    pipeline_parallel_size=2,
     distributed_executor_backend="external_launcher",
-    seed=0,
+    max_model_len=32768,
+    seed=1,
 )

 outputs = llm.generate(prompts, sampling_params)

 # all ranks will have the same outputs
-print("-" * 50)
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}\n"
-          f"Generated text: {generated_text!r}")
+if dist.get_rank() == 0:
     print("-" * 50)
-"""
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\n"
+              f"Generated text: {generated_text!r}\n")
+        print("-" * 50)
+    """
 Further tips:

 1. to communicate control messages across all ranks, use the cpu group,
diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py
index 0420a6454d46..bb38e908b734 100644
--- a/tests/distributed/test_torchrun_example.py
+++ b/tests/distributed/test_torchrun_example.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 # unit test for `examples/offline_inference/torchrun_example.py`
-
+import os
 import random

 import torch.distributed as dist
@@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 # to test if all ranks agree on the same kv cache configuration.
 llm = LLM(model="facebook/opt-125m",
           tensor_parallel_size=2,
+          pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
           distributed_executor_backend="external_launcher",
           gpu_memory_utilization=random.uniform(0.7, 0.9),
           swap_space=random.randint(1, 4),
diff --git a/vllm/config.py b/vllm/config.py
index dddfdabd126a..d07a1ff05234 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1695,7 +1695,6 @@ class ParallelConfig:
     """Port of the data parallel master."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
-    max_parallel_loading_workers: Optional[int] = None
     """Maximum number of parallel loading workers when loading model
     sequentially in multiple batches. To avoid RAM OOM when using tensor
     parallel and large models."""
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 7567161b6ac7..5c2dbcc27b13 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -265,7 +265,8 @@ class CustomAllreduce:

     def close(self):
         if not self.disabled and self._ptr:
-            ops.dispose(self._ptr)
+            if ops is not None:
+                ops.dispose(self._ptr)
             self._ptr = 0
             self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
             self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
@@ -298,4 +299,5 @@ class CustomAllreduce:
                            rank: Optional[int] = 0) -> None:
         if rank is None:
             rank = dist.get_rank(group=group)
-        ops.free_shared_buffer(pointers[rank])
+        if ops is not None:
+            ops.free_shared_buffer(pointers[rank])
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6fdb5e6c3772..dc2bb3a52cac 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1383,9 +1383,10 @@ class EngineArgs:
             return False

         if (self.pipeline_parallel_size > 1
-                and self.distributed_executor_backend not in ["ray", "mp"]):
+                and self.distributed_executor_backend
+                not in ("ray", "mp", "external_launcher")):
             name = "Pipeline Parallelism without Ray distributed executor " \
-                "or multiprocessing executor"
+                "or multiprocessing executor or external launcher"
             _raise_or_fallback(feature_name=name, recommend_to_remove=False)
             return False

diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py
index 2e4b47c1e24a..1d3a6e443a80 100644
--- a/vllm/executor/uniproc_executor.py
+++ b/vllm/executor/uniproc_executor.py
@@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
         """
-        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
-            ("ExecutorWithExternalLauncher does not "
-             "support pipeline parallelism.")
         assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
             ("ExecutorWithExternalLauncher needs deterministic "
              "execution, so it"
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0788ac5adde8..cb802fd4f102 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -22,7 +22,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
-    get_pp_group, graph_capture, prepare_communication_buffer_for_model)
+    get_pp_group, get_tp_group, graph_capture,
+    prepare_communication_buffer_for_model)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -1162,13 +1163,32 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             hidden_states, aux_hidden_states = model_output
         else:
             hidden_states = model_output
-
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        broadcast_pp_output = \
+            self.parallel_config.distributed_executor_backend \
+            == "external_launcher" and len(get_pp_group().ranks) > 0
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
-            return hidden_states
-
-        sample_hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+            if not broadcast_pp_output:
+                return hidden_states
+            assert isinstance(hidden_states, IntermediateTensors)
+            get_pp_group().send_tensor_dict(hidden_states.tensors,
+                                            all_gather_group=get_tp_group())
+            logits = None
+        else:
+            sample_hidden_states = hidden_states[logits_indices]
+            logits = self.model.compute_logits(sample_hidden_states, None)
+        if broadcast_pp_output:
+            model_output_broadcast_data = {
+                "logits": logits.contiguous(),
+            } if logits is not None else {}
+            model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+                model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
+            assert model_output_broadcast_data is not None
+            logits = model_output_broadcast_data["logits"]

         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
@@ -1186,6 +1206,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # creates a new tensor with separate storage from the original
             # logits tensor. This means any in-place operations on bonus_logits
             # won't affect the original logits tensor.
+            assert logits is not None
             bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
             sampler_output = self.sampler(
                 logits=bonus_logits,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d85701fa93df..93129d987940 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -275,13 +275,13 @@ class Worker(WorkerBase):

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
-
-        if not get_pp_group().is_last_rank:
+        parallel_config = self.vllm_config.parallel_config
+        if parallel_config.distributed_executor_backend != "external_launcher" \
+            and not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
             return None
-
         assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None
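Taken together, the patch lets torchrun drive both tensor and pipeline parallelism through the external_launcher backend: every launched process builds an identical engine, and the new logits broadcast in gpu_model_runner.py keeps all pipeline ranks in sync so they return the same outputs. The sketch below is illustrative only and not part of the diff; the script name, prompts, and the choice of facebook/opt-125m on a 4-GPU node are assumptions that mirror the settings in tests/distributed/test_torchrun_example.py.

```python
# pp_torchrun_sketch.py -- illustrative only; assumed launch command:
#   torchrun --nproc-per-node=4 pp_torchrun_sketch.py
# (tensor_parallel_size * pipeline_parallel_size must equal the process count)
import torch.distributed as dist

from vllm import LLM, SamplingParams

# Identical prompts and sampling parameters on every rank.
prompts = [
    "Hello, my name is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Each torchrun process creates exactly one worker; no Ray or
# multiprocessing executor is involved.
llm = LLM(
    model="facebook/opt-125m",  # assumed small model, as in the unit test
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    distributed_executor_backend="external_launcher",
    seed=1,
)

# Because the last PP rank broadcasts its logits, all ranks end up with the
# same sampled outputs; print from rank 0 only to keep the log readable.
outputs = llm.generate(prompts, sampling_params)
if dist.get_rank() == 0:
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated text: {output.outputs[0].text!r}")
```

This is essentially the path exercised by the new CI step (`PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py`).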