Fix Flashinfer CUTLASS MOE Allgather (#21963)

Signed-off-by: Shu Wang <shuw@nvidia.com>
This commit is contained in:
Shu Wang 2025-08-07 21:18:25 -05:00 committed by GitHub
parent a3b9c17b56
commit b2c8ce57c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 71 additions and 27 deletions

View File

@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[dim] == sizes[self.rank_in_group]
assert input_.shape[dim] == sizes[self.rank_in_group], (
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
output_size = (sum(sizes), ) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]

View File

@ -26,10 +26,26 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
max_num_tokens: int,
chunk_idx: int) -> list[int]:
dp_size = len(num_tokens_across_dp_cpu)
local_size = [-1] * dp_size
for i in range(dp_size):
dp_tokens = num_tokens_across_dp_cpu[i]
local_size[i] = min(max_num_tokens,
dp_tokens - (max_num_tokens * chunk_idx))
if local_size[i] <= 0:
local_size[i] = 1 # ensure lockstep even if done
return local_size
@dataclass
class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
local_sizes: Optional[list[int]] = None
@staticmethod
def num_tokens_across_dp(num_tokens: int, dp_size: int,
@ -78,6 +94,48 @@ class DPMetadata:
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
@contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
"""
Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution.
This is necessary to ensure each DP (data parallel) rank processes its
designated portion of tokens in lockstep with others, even when the
token counts are uneven or some ranks have completed their input early.
For chunked execution, we break up the total tokens on each rank into
multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
`chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank.
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
to determine the chunk-wise split.
`self.local_sizes` is only valid inside the context.
Args:
max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for.
"""
cu_sizes = self.cu_tokens_across_dp_cpu
num_tokens_across_dp_cpu = [
(cu_sizes[i] -
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
for i in range(len(cu_sizes))
]
self.local_sizes = _compute_chunked_local_num_tokens(
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
try:
yield self.local_sizes
finally:
self.local_sizes = None
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
return self.local_sizes
@dataclass
class ForwardContext:

View File

@ -4,7 +4,6 @@ from typing import Any, Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
@ -14,20 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes(local_tokens):
cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
sizes = [cu_sizes[0].item()]
for i in range(1, len(cu_sizes)):
sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
sizes_chunked = [max_num_tokens] * len(sizes)
if local_tokens < max_num_tokens:
# When the number of local tokens is less than max_num_tokens, all other
# ranks will also have fewer than max_num_tokens. The remaining tokens
# are accounted for as residual.
sizes_chunked = [x % max_num_tokens for x in sizes]
return sizes_chunked
def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@ -90,7 +77,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights, topk_ids, a1q, a1q_scale = \
get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
dim=0,
sizes=get_local_sizes(local_tokens))
sizes=get_local_sizes())
a1_m, a1_n = a1q.shape
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
@ -107,8 +94,5 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
['use_dp', 'local_tokens'])
if use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output,
dim=0,
sizes=get_local_sizes(local_tokens),
)
fused_expert_output, dim=0, sizes=get_local_sizes())
output.copy_(fused_expert_output)

View File

@ -1570,18 +1570,19 @@ class FusedMoE(torch.nn.Module):
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank):
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dp)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
chunk_idx):
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states