diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 0af16bbc0007..f192be1c40d5 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -47,8 +47,12 @@ class DPMetadata:
         return num_tokens_tensor
 
     @staticmethod
-    def make(parallel_config: ParallelConfig, attn_metadata: Any,
-             num_tokens: int) -> "DPMetadata":
+    def make(
+        parallel_config: ParallelConfig,
+        attn_metadata: Any,
+        num_tokens: int,
+        num_tokens_across_dp: Optional[torch.Tensor] = None
+    ) -> "DPMetadata":
 
         assert parallel_config.data_parallel_size > 1
         dp_size = parallel_config.data_parallel_size
@@ -62,10 +66,15 @@ class DPMetadata:
             # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
 
-        num_tokens_tensor = DPMetadata.num_tokens_across_dp(
-            batchsize, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
-        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        # If num_tokens_across_dp is None, it will be computed by all_reduce
+        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
+        assert (num_tokens_across_dp is None
+                or num_tokens_across_dp[dp_rank] == batchsize)
+        if num_tokens_across_dp is None:
+            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+                batchsize, dp_size, dp_rank)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
 
 
@@ -101,7 +110,8 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: Optional[int] = None):
+                        num_tokens: Optional[int] = None,
+                        num_tokens_across_dp: Optional[torch.Tensor] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -114,7 +124,8 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1 and (
             attn_metadata is not None or num_tokens is not None):
         dp_metadata = DPMetadata.make(vllm_config.parallel_config,
-                                      attn_metadata, num_tokens or 0)
+                                      attn_metadata, num_tokens or 0,
+                                      num_tokens_across_dp)
 
     global _forward_context
     prev_context = _forward_context
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b6fa68ab0925..4bc825ccb335 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1111,17 +1111,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 for k, v in self.intermediate_tensors.items()
             })
 
-    def get_dp_padding(self, num_tokens: int):
+    def get_dp_padding(self,
+                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        if dp_size == 1:
+
+        # For DP: Don't pad when setting enforce_eager.
+        # This lets us set enforce_eager on the prefiller in a P/D setup and
+        # still use CUDA graphs (enabled by this padding) on the decoder.
+        #
+        # TODO(tms) : There are many cases where padding is enabled for
+        # prefills, causing unnecessary and excessive padding of activations.
+
+        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
             # Early exit.
-            return 0
+            return 0, None
 
         num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
             num_tokens, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
-        return max_tokens_across_dp_cpu - num_tokens
+        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
+                                                dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
 
     @torch.inference_mode()
     def execute_model(
@@ -1161,7 +1174,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_input_tokens = num_scheduled_tokens
 
         # Padding for DP
-        num_input_tokens += self.get_dp_padding(num_input_tokens)
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
+        num_input_tokens += num_pad
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1208,7 +1222,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             self.maybe_setup_kv_connector(scheduler_output)
 
             model_output = self.model(
@@ -1681,7 +1696,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> torch.Tensor:
 
         # Padding for DP
-        num_tokens += self.get_dp_padding(num_tokens)
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
+        num_tokens += num_pad
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1747,9 +1763,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
 
-        with set_forward_context(attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_tokens):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp):
             outputs = model(
                 input_ids=input_ids,
                 positions=positions,
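Reviewer note: the new `get_dp_padding` contract (a pad amount plus a per-rank token-count tensor that is uniform after padding) is easiest to see in isolation. Below is a minimal stand-alone sketch of that contract; `mock_get_dp_padding` and the example values are hypothetical stand-ins, and the hand-supplied tensor replaces the all-reduce that the real `DPMetadata.num_tokens_across_dp` performs across DP ranks.

```python
from typing import Optional

import torch


def mock_get_dp_padding(
        num_tokens: int, num_tokens_across_dp: torch.Tensor,
        enforce_eager: bool) -> tuple[int, Optional[torch.Tensor]]:
    """Illustrative stand-in for GPUModelRunner.get_dp_padding."""
    dp_size = num_tokens_across_dp.numel()
    if dp_size == 1 or enforce_eager:
        # No padding; the caller then lets DPMetadata.make run its own
        # all-reduce, since it receives None here.
        return 0, None
    max_tokens = int(torch.max(num_tokens_across_dp).item())
    # After padding, every rank runs with max_tokens, so the tensor handed
    # on to set_forward_context is constant across ranks.
    num_tokens_after_padding = torch.tensor([max_tokens] * dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return max_tokens - num_tokens, num_tokens_after_padding


# Example: a rank with 3 tokens in a DP=4 group with uneven batch sizes.
pad, padded = mock_get_dp_padding(3, torch.tensor([5, 3, 7, 2]), False)
assert pad == 4 and padded.tolist() == [7, 7, 7, 7]
```

Threading `num_tokens_after_padding` through `set_forward_context` is what lets `DPMetadata.make` skip the all-reduce when the padded per-rank sizes are already known.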