diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 3c8083e3dd0dd..592ca650a5546 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -10,7 +10,7 @@ import torch
 import torch.distributed as dist
 
 import vllm.envs as envs
-from vllm.config import VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
@@ -30,6 +30,44 @@ class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
 
+    @staticmethod
+    def num_tokens_across_dp(num_tokens: int, dp_size: int,
+                             dp_rank: int) -> torch.Tensor:
+        """
+        Gather the num_tokens across all DP ranks and return results in a
+        CPU tensor of size dp_size.
+        """
+        num_tokens_across_dp = [0] * dp_size
+        num_tokens_across_dp[dp_rank] = num_tokens
+        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
+                                         device="cpu",
+                                         dtype=torch.int32)
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
+        return num_tokens_tensor
+
+    @staticmethod
+    def make(parallel_config: ParallelConfig, attn_metadata: Any,
+             num_tokens: int) -> "DPMetadata":
+
+        assert parallel_config.data_parallel_size > 1
+        dp_size = parallel_config.data_parallel_size
+        dp_rank = parallel_config.data_parallel_rank
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            # for v1 attention backends or no attn_metadata
+            batchsize = num_tokens
+
+        num_tokens_tensor = DPMetadata.num_tokens_across_dp(
+            batchsize, dp_size, dp_rank)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+
 
 @dataclass
 class ForwardContext:
@@ -74,27 +112,8 @@ def set_forward_context(attn_metadata: Any,
         forward_start_time = time.perf_counter()
     dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1:
-        dp_size = vllm_config.parallel_config.data_parallel_size
-        dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None and hasattr(attn_metadata,
-                                                 "num_prefill_tokens"):
-            # for v0 attention backends
-            batchsize = attn_metadata.num_prefill_tokens + \
-                attn_metadata.num_decode_tokens
-        else:
-            # for v1 attention backends or no attn_metadata
-            batchsize = num_tokens
-        num_tokens_across_dp = [0] * dp_size
-        num_tokens_across_dp[dp_rank] = batchsize
-        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
-                                         device="cpu",
-                                         dtype=torch.int32)
-        from vllm.distributed.parallel_state import get_dp_group
-        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
-        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
-        dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
-                                 cu_tokens_across_dp_cpu)
+        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
+                                      attn_metadata, num_tokens)
 
     global _forward_context
     prev_context = _forward_context
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 29b41e720852c..838a7c24b642f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+        # Chunked all2all staging tensor
+        self.batched_hidden_states: Optional[torch.Tensor] = None
+        self.batched_router_logits: Optional[torch.Tensor] = None
+        if self.moe_parallel_config.use_pplx_kernels:
+            act_dtype = vllm_config.model_config.dtype
+            self.batched_hidden_states = torch.zeros(
+                (MOE_DP_CHUNK_SIZE, self.hidden_size),
+                dtype=act_dtype,
+                device=torch.cuda.current_device())
+
+            self.batched_router_logits = torch.zeros(
+                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
+                dtype=act_dtype,
+                device=torch.cuda.current_device())
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
+        assert self.batched_hidden_states is not None
+        assert self.batched_router_logits is not None
+        assert self.batched_hidden_states.dtype == full_hidden_states.dtype
+        assert self.batched_router_logits.dtype == full_router_logits.dtype
+        # Check size compatibility.
+        assert (
+            self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
+        assert (
+            self.batched_router_logits.size(-1) == full_router_logits.size(-1))
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
         def process_chunk(chunk_start, chunk_end, skip_result_store=False):
+            chunk_size = chunk_end - chunk_start
             hidden_states = full_hidden_states[chunk_start:chunk_end, :]
            router_logits = full_router_logits[chunk_start:chunk_end, :]
 
+            assert (self.batched_hidden_states.size(0)  # type: ignore
+                    >= chunk_size)
+            assert (self.batched_router_logits.size(0)  # type: ignore
+                    >= chunk_size)
+            staged_hidden_states = self.batched_hidden_states[:
+                                                              chunk_size, :]  # type: ignore
+            staged_router_logits = self.batched_router_logits[:
+                                                              chunk_size, :]  # type: ignore
+            staged_hidden_states.copy_(hidden_states, non_blocking=True)
+            staged_router_logits.copy_(router_logits, non_blocking=True)
+
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
                 layer=self,
-                x=hidden_states,
-                router_logits=router_logits,
+                x=staged_hidden_states,
+                router_logits=staged_router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
                 use_grouped_topk=self.use_grouped_topk,
@@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
 
             if not skip_result_store:
                 full_final_hidden_states[chunk_start:chunk_end, :].copy_(
-                    final_hidden_states)
+                    final_hidden_states, non_blocking=True)
 
         ctx = get_forward_context()
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 0bed44f732770..9f833cbb587d8 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -106,7 +106,6 @@ class CudaPlatformBase(Platform):
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
-        compilation_config = vllm_config.compilation_config
         model_config = vllm_config.model_config
 
         if parallel_config.worker_cls == "auto":
@@ -154,16 +153,6 @@ class CudaPlatformBase(Platform):
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")
 
-        if (parallel_config.data_parallel_size > 1
-                and compilation_config.use_cudagraph):
-            logger.info(
-                "Data Parallel: Forcing enforce eager to be True since DP is "
-                "currently not supported with CUDA Graphs.")
-            vllm_config.model_config.enforce_eager = True
-            compilation_config.use_cudagraph = False
-            # FIXME: inductor breaks cudagraph (from @bnell)
-            compilation_config.use_inductor = False
-
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5d5558162ab37..d1195bcfb27b9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -24,7 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import (DPMetadata, get_forward_context,
+                                  set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model
@@ -1104,6 +1105,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             for k, v in self.intermediate_tensors.items()
         })
 
+    def get_dp_padding(self, num_tokens: int):
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+        if dp_size == 1:
+            # Early exit.
+            return 0
+
+        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+            num_tokens, dp_size, dp_rank)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+        return max_tokens_across_dp_cpu - num_tokens
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1141,6 +1154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             num_input_tokens = num_scheduled_tokens
 
+        # Padding for DP
+        num_input_tokens += self.get_dp_padding(num_input_tokens)
+
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
@@ -1658,6 +1674,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         skip_attn: bool = True,
     ) -> torch.Tensor:
 
+        # Padding for DP
+        num_tokens += self.get_dp_padding(num_tokens)
+
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
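
A note on the forward_context.py change above: DPMetadata.make() gathers every DP rank's token count into a CPU tensor (via an all_reduce over the DP CPU group inside num_tokens_across_dp) and then keeps two derived values, the maximum count and the prefix sums. A minimal single-process illustration of that last step, using made-up counts in place of the gathered tensor so no process group is needed:

import torch

# Stand-in for the result of DPMetadata.num_tokens_across_dp() with a
# hypothetical dp_size == 4; in vLLM this tensor is filled by
# dist.all_reduce over the DP CPU group.
num_tokens_tensor = torch.tensor([7, 12, 3, 12], device="cpu", dtype=torch.int32)

max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)           # 12
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)  # 7, 19, 22, 34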
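
On the fused_moe/layer.py change: forward_impl_chunked() now copies each chunk of activations and router logits into persistent staging tensors (batched_hidden_states / batched_router_logits) allocated once with MOE_DP_CHUNK_SIZE rows, and the quant method runs on those staged views. Below is a stripped-down, CPU-only sketch of the copy-in / compute / copy-out pattern; CHUNK_SIZE, HIDDEN and expert_fn are illustrative stand-ins, not vLLM names:

import torch

CHUNK_SIZE = 4   # stands in for MOE_DP_CHUNK_SIZE
HIDDEN = 8

# Allocated once, like self.batched_hidden_states in the diff.
staging = torch.zeros((CHUNK_SIZE, HIDDEN))

def expert_fn(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for self.quant_method.apply(); just scales the input.
    return x * 2.0

def forward_chunked(full_hidden_states: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(full_hidden_states)
    num_tokens = full_hidden_states.size(0)
    for start in range(0, num_tokens, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, num_tokens)
        n = end - start
        # Copy the chunk into the head of the persistent staging buffer,
        staging[:n].copy_(full_hidden_states[start:end])
        # run the expert computation on the staged slice,
        result = expert_fn(staging[:n])
        # and copy the result back into the full output tensor.
        out[start:end].copy_(result)
    return out

x = torch.randn(10, HIDDEN)
assert torch.allclose(forward_chunked(x), x * 2.0)

The real code additionally issues the copies with non_blocking=True and has every DP rank execute the same number of chunk iterations (using skip_result_store=True for padding-only chunks), which this single-rank sketch omits.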
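
On the gpu_model_runner.py and platforms/cuda.py changes: get_dp_padding() makes each DP rank pad its batch up to the largest batch in the DP group, so all ranks execute the same (CUDA-graph-capturable) token count each step; with that in place, the cuda.py code that forced enforce_eager under DP is dropped. A self-contained sketch of the padding rule, with the all_reduce replaced by a plain Python list so it runs in one process (dp_padding and counts are illustrative names):

def dp_padding(local_num_tokens: int, num_tokens_across_dp: list[int]) -> int:
    # Each rank pads up to the maximum token count in the DP group,
    # mirroring max_tokens_across_dp_cpu - num_tokens in get_dp_padding().
    return max(num_tokens_across_dp) - local_num_tokens

# Example: 4 DP ranks scheduled 7, 12, 3 and 12 tokens this step.
counts = [7, 12, 3, 12]
padded = [n + dp_padding(n, counts) for n in counts]
assert padded == [12, 12, 12, 12]  # every rank now runs the same batch size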