[Core] Enable CUDA graphs for DP + All2All kernels (#18724)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-05-28 18:55:30 -04:00 committed by GitHub
parent 6dbe5b5c93
commit 7951d78738
4 changed files with 100 additions and 37 deletions

vllm/forward_context.py

@@ -10,7 +10,7 @@ import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
@@ -30,6 +30,44 @@ class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
dp_rank: int) -> torch.Tensor:
"""
Gather the num_tokens across all DP ranks and return results in a
CPU tensor of size dp_size.
"""
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
return num_tokens_tensor
@staticmethod
def make(parallel_config: ParallelConfig, attn_metadata: Any,
num_tokens: int) -> "DPMetadata":
assert parallel_config.data_parallel_size > 1
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_tensor = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
@dataclass
class ForwardContext:
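
For reference, a minimal single-process sketch (not part of the diff) of what the DPMetadata helpers added above compute: each rank contributes its token count at its own index, the CPU all_reduce sums the per-rank vectors, and then the max gives the padding target while the cumulative sum gives the all2all offsets. The per-rank counts below are made-up example values.

import torch

# Made-up per-rank token counts for a dp_size=4 run.
dp_size = 4
local_num_tokens = [7, 3, 5, 2]   # what each rank would pass as num_tokens

# Each rank builds a vector that is zero except at its own dp_rank ...
per_rank_vectors = [
    torch.tensor([n if i == rank else 0 for i in range(dp_size)],
                 dtype=torch.int32)
    for rank, n in enumerate(local_num_tokens)
]
# ... and the all_reduce (sum) over the DP CPU group combines them.
num_tokens_tensor = sum(per_rank_vectors)                     # tensor([7, 3, 5, 2])

max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)       # 7
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, 0)  # [7, 10, 15, 17]
print(max_tokens_across_dp_cpu.item(), cu_tokens_across_dp_cpu.tolist())
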
@@ -74,27 +112,8 @@ def set_forward_context(attn_metadata: Any,
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
cu_tokens_across_dp_cpu)
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens)
global _forward_context
prev_context = _forward_context

vllm/model_executor/layers/fused_moe/layer.py

@@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
device=torch.cuda.current_device())
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
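
The fixed-size staging buffers allocated above are what make the all2all path CUDA-graph friendly: graph replay needs stable tensor addresses and shapes, so variable-sized chunks are copied into a preallocated MOE_DP_CHUNK_SIZE buffer rather than handed to the kernels directly. A minimal sketch of that pattern (illustrative only; MAX_CHUNK, HIDDEN, and the toy kernel are made up):

import torch

MAX_CHUNK = 8     # stands in for MOE_DP_CHUNK_SIZE (example value)
HIDDEN = 16

# Allocated once, so its storage address never changes across steps.
staging = torch.zeros(MAX_CHUNK, HIDDEN)

def run_step(chunk: torch.Tensor) -> torch.Tensor:
    # Copy the variable-sized chunk into the leading rows of the fixed
    # buffer; downstream kernels always see the same storage.
    n = chunk.size(0)
    assert n <= MAX_CHUNK
    staging[:n].copy_(chunk, non_blocking=True)
    return staging[:n] * 2.0   # toy stand-in for quant_method.apply

out = run_step(torch.randn(5, HIDDEN))
print(out.shape)   # torch.Size([5, 16])
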
@@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
assert self.batched_router_logits.dtype == full_router_logits.dtype
# Check size compatibility.
assert (
self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
full_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
x=staged_hidden_states,
router_logits=staged_router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
@@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states)
final_hidden_states, non_blocking=True)
ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
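
A rough sketch (not the vLLM implementation) of the chunking scheme that forward_impl_chunked follows: every DP rank walks the DP-wide maximum token count in fixed-size steps so that all ranks issue the same number of all2all exchanges, and a rank that has run out of its own tokens still executes the chunk but skips storing results, which is what the skip_result_store flag above controls. All names and sizes below are illustrative.

import torch

CHUNK = 4     # stands in for MOE_DP_CHUNK_SIZE (example value)

def chunked_forward(full_x, max_tokens_across_dp, process_chunk):
    out = torch.empty_like(full_x)
    num_tokens = full_x.size(0)
    # Iterate over the *max* token count across DP ranks so every rank
    # runs the same number of chunks (and hence collective calls).
    for start in range(0, max_tokens_across_dp, CHUNK):
        end = min(start + CHUNK, max_tokens_across_dp)
        # Clamp to this rank's own tokens; beyond them we still run the
        # chunk for the collective but do not store the result.
        local_start = min(start, num_tokens)
        local_end = min(end, num_tokens)
        skip_result_store = local_end == local_start
        chunk_out = process_chunk(full_x[local_start:local_end])
        if not skip_result_store:
            out[local_start:local_end].copy_(chunk_out, non_blocking=True)
    return out

# Toy usage: this rank has 6 tokens while the largest DP rank has 10.
x = torch.randn(6, 8)
y = chunked_forward(x, max_tokens_across_dp=10, process_chunk=lambda t: t * 2)
print(torch.allclose(y, x * 2))   # True
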

vllm/platforms/cuda.py

@@ -106,7 +106,6 @@ class CudaPlatformBase(Platform):
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto":
@@ -154,16 +153,6 @@ class CudaPlatformBase(Platform):
logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.")
if (parallel_config.data_parallel_size > 1
and compilation_config.use_cudagraph):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP is "
"currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False
# FIXME: inductor breaks cudagraph (from @bnell)
compilation_config.use_inductor = False
@classmethod
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None

vllm/v1/worker/gpu_model_runner.py

@@ -24,7 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model
@@ -1104,6 +1105,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
def get_dp_padding(self, num_tokens: int):
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
if dp_size == 1:
# Early exit.
return 0
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
return max_tokens_across_dp_cpu - num_tokens
@torch.inference_mode()
def execute_model(
self,
@@ -1141,6 +1154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_input_tokens += self.get_dp_padding(num_input_tokens)
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if self.is_multimodal_model:
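
get_dp_padding (added above) returns how many extra tokens this rank must pad so that its batch matches the largest batch across DP ranks; execute_model then adds that padding to num_input_tokens. With identical padded sizes on every rank, a CUDA graph captured for that size can be replayed in lock-step everywhere. A worked example with made-up per-rank counts:

import torch

# Made-up scheduled token counts for dp_size=4.
num_tokens_across_dp = torch.tensor([48, 64, 17, 33], dtype=torch.int32)
max_tokens = int(torch.max(num_tokens_across_dp))     # 64

# Each rank pads up to the DP-wide maximum, mirroring
# num_input_tokens += self.get_dp_padding(num_input_tokens) above.
for rank, n in enumerate(num_tokens_across_dp.tolist()):
    padding = max_tokens - n
    print(f"rank {rank}: {n} tokens + {padding} padding -> {n + padding}")
# Every rank executes 64 tokens, so the same captured graph size fits all ranks.
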
@@ -1658,6 +1674,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
skip_attn: bool = True,
) -> torch.Tensor:
# Padding for DP
num_tokens += self.get_dp_padding(num_tokens)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
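
The dummy runs used for graph capture apply the same DP padding, so the sizes captured ahead of time match the padded sizes seen during execution. As background (generic PyTorch, not vLLM code), a minimal sketch of the capture-and-replay pattern this relies on: capture once on a fixed-size static buffer, then copy new, possibly padded, inputs into that buffer and replay.

import torch

assert torch.cuda.is_available()                  # CUDA graphs need a GPU
static_x = torch.zeros(64, 128, device="cuda")    # fixed padded batch size

def step(x):
    return (x @ x.T).relu()

# Warm up on a side stream (recommended before capture), then capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    step(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = step(static_x)

# Replay: copy real tokens into the static buffer, leave padding as zeros.
new_batch = torch.randn(48, 128, device="cuda")   # 48 real tokens, 16 padding
static_x.zero_()
static_x[:48].copy_(new_batch)
graph.replay()
print(static_out.shape)                           # torch.Size([64, 64])
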