[Core] Enable CUDA graphs for DP + All2All kernels (#18724)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-05-28 18:55:30 -04:00 committed by GitHub
parent 6dbe5b5c93
commit 7951d78738
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 37 deletions

View File

@ -10,7 +10,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
@ -30,6 +30,44 @@ class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor
@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
dp_rank: int) -> torch.Tensor:
"""
Gather the num_tokens across all DP ranks and return results in a
CPU tensor of size dp_size.
"""
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = num_tokens
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
return num_tokens_tensor
@staticmethod
def make(parallel_config: ParallelConfig, attn_metadata: Any,
num_tokens: int) -> "DPMetadata":
assert parallel_config.data_parallel_size > 1
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_tensor = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
@dataclass @dataclass
class ForwardContext: class ForwardContext:
@ -74,27 +112,8 @@ def set_forward_context(attn_metadata: Any,
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
if vllm_config.parallel_config.data_parallel_size > 1: if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size dp_metadata = DPMetadata.make(vllm_config.parallel_config,
dp_rank = vllm_config.parallel_config.data_parallel_rank attn_metadata, num_tokens)
if attn_metadata is not None and hasattr(attn_metadata,
"num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends or no attn_metadata
batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
cu_tokens_across_dp_cpu)
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context

View File

@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
device=torch.cuda.current_device())
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
assert self.batched_router_logits.dtype == full_router_logits.dtype
# Check size compatibility.
assert (
self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
full_final_hidden_states = torch.empty_like(full_hidden_states) full_final_hidden_states = torch.empty_like(full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False): def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :] hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :]
assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore
staged_router_logits = self.batched_router_logits[:
chunk_size, :] # type: ignore
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=hidden_states, x=staged_hidden_states,
router_logits=router_logits, router_logits=staged_router_logits,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
if not skip_result_store: if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_( full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states) final_hidden_states, non_blocking=True)
ctx = get_forward_context() ctx = get_forward_context()
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu

View File

@ -106,7 +106,6 @@ class CudaPlatformBase(Platform):
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
@ -154,16 +153,6 @@ class CudaPlatformBase(Platform):
logger.info( logger.info(
"Forcing kv cache block size to 64 for FlashMLA backend.") "Forcing kv cache block size to 64 for FlashMLA backend.")
if (parallel_config.data_parallel_size > 1
and compilation_config.use_cudagraph):
logger.info(
"Data Parallel: Forcing enforce eager to be True since DP is "
"currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False
# FIXME: inductor breaks cudagraph (from @bnell)
compilation_config.use_inductor = False
@classmethod @classmethod
def get_current_memory_usage(cls, def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None device: Optional[torch.types.Device] = None

View File

@ -24,7 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model) prepare_communication_buffer_for_model)
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model from vllm.model_executor.model_loader import TensorizerLoader, get_model
@ -1104,6 +1105,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items() for k, v in self.intermediate_tensors.items()
}) })
def get_dp_padding(self, num_tokens: int):
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
if dp_size == 1:
# Early exit.
return 0
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
return max_tokens_across_dp_cpu - num_tokens
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
@ -1141,6 +1154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
# Padding for DP
num_input_tokens += self.get_dp_padding(num_input_tokens)
# _prepare_inputs may reorder the batch, so we must gather multi # _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order # modal outputs after that to ensure the correct order
if self.is_multimodal_model: if self.is_multimodal_model:
@ -1658,6 +1674,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
skip_attn: bool = True, skip_attn: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# Padding for DP
num_tokens += self.get_dp_padding(num_tokens)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively # for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total. # has num_tokens in total.