diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py
index 65c5e6896844..ded3d834faf0 100644
--- a/tests/distributed/test_sequence_parallel.py
+++ b/tests/distributed/test_sequence_parallel.py
@@ -235,7 +235,6 @@ def _compare_sp(
         'level': 3,
         'custom_ops': ["+rms_norm"],
         'compile_sizes': [4, 8],
-        'splitting_ops': [],
         'pass_config': {
             'enable_sequence_parallelism': True,
             'enable_fusion': enable_fusion,
@@ -251,6 +250,8 @@ def _compare_sp(
         *common_args,
         "--tensor-parallel-size",
         str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
         "--distributed-executor-backend",
         distributed_backend,
         "--compilation_config",
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index ef229299b684..12571afaa4c1 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -663,14 +663,29 @@ class GroupCoordinator:
         tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
@@ -699,14 +714,23 @@ class GroupCoordinator:
         # `send_object_list` has serialization & deserialization,
         # all happening on CPU. Therefore, we can use the CPU group.
         self.send_object(metadata_list, dst=dst)
-        for tensor in tensor_list:
+
+        tensor_keys = [
+            k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
+        ]
+        assert len(tensor_keys) == len(tensor_list)
+
+        for key, tensor in zip(tensor_keys, tensor_list):
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue
 
             # send-allgather: send only a slice, then do allgather.
-            if (all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0):
+            use_all_gather = (all_gather_group is not None
+                              and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather
+            if use_all_gather:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
             if tensor.is_cpu:
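For context, the per-tensor override added above is driven entirely by the caller. A minimal sketch of how a pipeline-parallel boundary might use it (the function name and tensor names here are illustrative, not part of the patch):

```python
# Illustrative only: `pp_group` / `tp_group` stand for the coordinators
# returned by get_pp_group() / get_tp_group(); the tensor names are examples.
import torch


def send_pp_boundary(pp_group, tp_group, hidden_states: torch.Tensor,
                     residual: torch.Tensor) -> None:
    # "hidden_states" is replicated across the TP group, so the default
    # slice-and-allgather path is safe for it. "residual" may already be
    # sharded by sequence parallelism, so it is explicitly opted out.
    pp_group.send_tensor_dict(
        {"hidden_states": hidden_states, "residual": residual},
        all_gather_group=tp_group,
        all_gather_tensors={"residual": False},
    )
```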
@@ -725,14 +749,29 @@ class GroupCoordinator:
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
@@ -766,6 +805,8 @@ class GroupCoordinator:
                 # send-allgather: send only a slice, then do allgather.
                 use_all_gather = (all_gather_group is not None
                                   and tensor.numel() % all_gather_size == 0)
+                use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                    if all_gather_tensors else use_all_gather
 
                 if use_all_gather:
                     orig_shape = tensor.shape
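The override matters because the default divisibility check cannot distinguish a replicated tensor from one that is already sharded. A small numeric sketch, with values chosen purely for illustration (TP=2):

```python
# With sequence parallelism, each rank already holds only its shard of the
# residual, yet the shard's element count is still divisible by the TP world
# size, so the old heuristic would slice it a second time.
tp_size = 2
hidden_size = 16
padded_tokens = 8                             # padded to a multiple of tp_size
rows_per_rank = padded_tokens // tp_size      # 4 residual rows per rank under SP
numel_per_rank = rows_per_rank * hidden_size  # 64
assert numel_per_rank % tp_size == 0          # divisibility still passes even
                                              # though the tensor is not replicated
```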
diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
index b87c4fe09bb9..daee91ec404f 100644
--- a/vllm/v1/worker/cpu_worker.py
+++ b/vllm/v1/worker/cpu_worker.py
@@ -19,6 +19,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
+from vllm.v1.worker.utils import is_residual_scattered_for_sp
 
 logger = init_logger(__name__)
 
@@ -107,18 +108,29 @@ class CPUWorker(Worker):
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         intermediate_tensors = None
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_input_tokens = self.model_runner._get_num_input_tokens(
+            num_scheduled_tokens)
+        all_gather_tensors = {
+            "residual":
+            not is_residual_scattered_for_sp(self.vllm_config,
+                                             num_input_tokens)
+        }
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors))
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
 
         if not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
+            get_pp_group().send_tensor_dict(
+                output.tensors,
+                all_gather_group=get_tp_group(),
+                all_gather_tensors=all_gather_tensors)
             return None
 
         assert isinstance(output, ModelRunnerOutput)
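The flag has to be computed identically on the sending and receiving pipeline ranks, otherwise one side would slice while the other expects a full tensor. A sketch of the shared pattern, with the worker code distilled into a hypothetical helper:

```python
from vllm.v1.worker.utils import is_residual_scattered_for_sp


def make_all_gather_tensors(worker, scheduler_output) -> dict[str, bool]:
    # Hypothetical helper: both CPUWorker and Worker derive the flag from the
    # same padded token count, so adjacent PP ranks always agree on whether
    # "residual" travels as a slice.
    num_scheduled = scheduler_output.total_num_scheduled_tokens
    num_input = worker.model_runner._get_num_input_tokens(num_scheduled)
    return {
        "residual":
        not is_residual_scattered_for_sp(worker.vllm_config, num_input)
    }
```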
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d4afaf51e6e8..d4d1f814afc0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -88,6 +88,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.utils import is_residual_scattered_for_sp
 
 from .utils import (AttentionGroup, MultiModalBudget,
                     add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
@@ -1633,21 +1634,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert self.intermediate_tensors is not None
 
         tp = self.vllm_config.parallel_config.tensor_parallel_size
-        enabled_sp = self.compilation_config.pass_config. \
-            enable_sequence_parallelism
-        if enabled_sp:
-            # When sequence parallelism is enabled, we always pad num_tokens
-            # to be a multiple of tensor_parallel_size (tp) earlier
-            assert num_tokens % tp == 0
-        is_residual_scattered = tp > 1 and enabled_sp \
-            and num_tokens % tp == 0
+        is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)
 
         # When sequence parallelism is enabled, the "residual" tensor is sharded
         # across tensor parallel ranks, so each rank only needs its own slice.
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
-                is_scattered = k == "residual" and is_residual_scattered
+                is_scattered = k == "residual" and is_rs
                 copy_len = num_tokens // tp if is_scattered else \
                     num_tokens
                 self.intermediate_tensors[k][:copy_len].copy_(
@@ -1655,8 +1649,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return IntermediateTensors({
             k:
-            v[:num_tokens // tp]
-            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            v[:num_tokens //
+              tp] if k == "residual" and is_rs else v[:num_tokens]
             for k, v in self.intermediate_tensors.items()
         })
 
@@ -1741,6 +1735,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooler_output=pooler_output,
         )
 
+    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
+                and hasattr(self, "cudagraph_batch_sizes")
+                and self.cudagraph_batch_sizes
+                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
+            # Use CUDA graphs.
+            # Add padding to the batch size.
+            return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
+
+        # Eager mode.
+        # Pad tokens to a multiple of tensor_parallel_size when
+        # collective fusion for SP is enabled.
+        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        if (self.compilation_config.pass_config.enable_sequence_parallelism
+                and tp_size > 1):
+            return round_up(num_scheduled_tokens, tp_size)
+        return num_scheduled_tokens
+
     def _preprocess(
         self,
         scheduler_output: "SchedulerOutput",
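A worked example of the eager-mode branch of `_get_num_input_tokens`, using a local re-implementation of the rounding helper (numbers are illustrative):

```python
def round_up(x: int, multiple: int) -> int:
    # Same rounding rule as the helper used by _get_num_input_tokens.
    return ((x + multiple - 1) // multiple) * multiple


# With SP enabled and tp_size = 4, 10 scheduled tokens become 12 input tokens
# so the residual can later be scattered evenly across the 4 TP ranks.
assert round_up(10, 4) == 12
assert round_up(8, 4) == 8  # already a multiple, no extra padding
```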
@@ -1750,24 +1763,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                Optional[IntermediateTensors], dict[str, Any]]:
 
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
-                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
-            # Use CUDA graphs.
-            # Add padding to the batch size.
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                num_scheduled_tokens)
-        else:
-            # Eager mode.
-            # Pad tokens to multiple of tensor_parallel_size when
-            # enabled collective fusion for SP
-            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-            if self.compilation_config.pass_config. \
-                enable_sequence_parallelism and tp_size > 1:
-                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
-            else:
-                num_input_tokens = num_scheduled_tokens
-
+        num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
         # Padding for DP
         num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
         num_input_tokens += num_pad
@@ -2108,8 +2104,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert not self.is_pooling_model
 
         if not get_pp_group().is_last_rank:
+            all_gather_tensors = {
+                "residual":
+                not is_residual_scattered_for_sp(
+                    self.vllm_config, num_input_tokens)
+            }
             get_pp_group().send_tensor_dict(
-                hidden_states.tensors, all_gather_group=get_tp_group())
+                hidden_states.tensors,
+                all_gather_group=get_tp_group(),
+                all_gather_tensors=all_gather_tensors)
             logits = None
         else:
             sample_hidden_states = hidden_states[logits_indices]
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 37dd431fd68f..6855526583f0 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -32,6 +32,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import is_residual_scattered_for_sp
 from vllm.v1.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
@@ -428,10 +429,19 @@ class Worker(WorkerBase):
     ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_input_tokens = self.model_runner._get_num_input_tokens(
+            num_scheduled_tokens)
+        all_gather_tensors = {
+            "residual":
+            not is_residual_scattered_for_sp(self.vllm_config,
+                                             num_input_tokens)
+        }
         if forward_pass and not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors))
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
@@ -444,7 +454,8 @@ class Worker(WorkerBase):
                     "external_launcher") and not get_pp_group().is_last_rank
 
             get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
+                                            all_gather_group=get_tp_group(),
+                                            all_gather_tensors=all_gather_tensors)
 
             kv_connector_output = output.kv_connector_output
             if not kv_connector_output:
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index be05d02ff29f..5ac7470c1ac9 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.cache import processor_only_cache_from_config
@@ -288,3 +288,28 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+def is_residual_scattered_for_sp(vllm_config: VllmConfig,
+                                 num_input_tokens: int) -> bool:
+    """Check if the residual tensor is scattered for sequence parallelism.
+
+    The residual tensor is scattered across tensor-parallel ranks when both
+    sequence parallelism and tensor parallelism are enabled, and the number
+    of input tokens is one of the compilation sizes.
+    """
+    if not vllm_config.compilation_config.pass_config.\
+        enable_sequence_parallelism:
+        return False
+
+    tp = vllm_config.parallel_config.tensor_parallel_size
+
+    if tp == 1:
+        return False
+
+    # When sequence parallelism is enabled, we always pad num_input_tokens
+    # to be a multiple of tensor_parallel_size (tp) earlier.
+    assert num_input_tokens % tp == 0
+
+    # Currently, SP is only enabled for static-size fx graphs.
+    return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
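Finally, the new helper's contract restated as a tiny reference implementation and check; illustrative only (the real function takes a `VllmConfig`), with compile sizes mirroring the test at the top of this diff:

```python
def residual_scattered(sp_enabled: bool, tp: int, compile_sizes: list[int],
                       num_input_tokens: int) -> bool:
    # Mirrors is_residual_scattered_for_sp: only static compile sizes are
    # sequence-parallelized, and tokens are assumed pre-padded to tp.
    if not sp_enabled or tp == 1:
        return False
    assert num_input_tokens % tp == 0
    return num_input_tokens in compile_sizes


assert residual_scattered(True, 2, [4, 8], 8) is True   # compiled size -> scattered
assert residual_scattered(True, 2, [4, 8], 6) is False  # dynamic size -> replicated
```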