mirror of https://git.datalinker.icu/vllm-project/vllm.git
Fix Flashinfer CUTLASS MOE Allgather (#21963)
Signed-off-by: Shu Wang <shuw@nvidia.com>
parent a3b9c17b56
commit b2c8ce57c6
@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
         input_size = input_.size()
         if sizes is not None:
             assert len(sizes) == world_size
-            assert input_.shape[dim] == sizes[self.rank_in_group]
+            assert input_.shape[dim] == sizes[self.rank_in_group], (
+                f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
             output_size = (sum(sizes), ) + input_size[1:]
         else:
             output_size = (input_size[0] * world_size, ) + input_size[1:]
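The assertion above now reports both sides of a size mismatch. A standalone sketch of the invariant it checks (the values below are illustrative, not from a real run; in vLLM the sizes come from DPMetadata): each rank's local shard length along the gather dimension must equal its own entry in `sizes`, and the gathered output then spans `sum(sizes)` rows.

# Illustration only: the size invariant checked before the variable-size all-gather.
sizes = [5, 8, 8]        # assumed per-rank token counts for this chunk
rank_in_group = 0
local_shard_rows = 5     # what input_.shape[dim] would be on this rank

assert local_shard_rows == sizes[rank_in_group], (
    f"{local_shard_rows} != {sizes[rank_in_group]}")
output_rows = sum(sizes)  # 21 rows after gathering all ranks' shards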
@@ -26,10 +26,26 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+                                      max_num_tokens: int,
+                                      chunk_idx: int) -> list[int]:
+    dp_size = len(num_tokens_across_dp_cpu)
+
+    local_size = [-1] * dp_size
+    for i in range(dp_size):
+        dp_tokens = num_tokens_across_dp_cpu[i]
+        local_size[i] = min(max_num_tokens,
+                            dp_tokens - (max_num_tokens * chunk_idx))
+        if local_size[i] <= 0:
+            local_size[i] = 1  # ensure lockstep even if done
+    return local_size
+
+
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
+    local_sizes: Optional[list[int]] = None
 
     @staticmethod
     def num_tokens_across_dp(num_tokens: int, dp_size: int,
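A standalone sketch (not part of the commit) that mirrors the helper above to show the chunk splits it produces, assuming per-rank token counts [5, 13] and a chunk size of 8: chunk 0 yields [5, 8], and chunk 1 yields [1, 5], where rank 0 has no tokens left but still reports 1 so every rank joins the collective in lockstep.

# Mirror of _compute_chunked_local_num_tokens, used only to illustrate its output.
def chunked_local_num_tokens(tokens_per_rank, max_num_tokens, chunk_idx):
    out = []
    for dp_tokens in tokens_per_rank:
        n = min(max_num_tokens, dp_tokens - max_num_tokens * chunk_idx)
        out.append(n if n > 0 else 1)  # keep finished ranks in lockstep
    return out

# Assumed counts: rank 0 has 5 tokens, rank 1 has 13; chunk size 8.
assert chunked_local_num_tokens([5, 13], 8, 0) == [5, 8]
assert chunked_local_num_tokens([5, 13], 8, 1) == [1, 5]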
@@ -78,6 +94,48 @@ class DPMetadata:
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
 
+    @contextmanager
+    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+        """
+        Context manager to compute and temporarily set the per-rank local token
+        sizes for a specific chunk during chunked forward execution.
+
+        This is necessary to ensure each DP (data parallel) rank processes its
+        designated portion of tokens in lockstep with others, even when the
+        token counts are uneven or some ranks have completed their input early.
+
+        For chunked execution, we break up the total tokens on each rank into
+        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
+        `chunk_idx`, this context manager sets `self.local_sizes` to the number
+        of tokens to process in that chunk on each rank.
+
+        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
+        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
+        to determine the chunk-wise split.
+
+        `self.local_sizes` is only valid inside the context.
+
+        Args:
+            max_chunk_size_per_rank: The max number of tokens each rank is
+                                     allowed to process in this chunk.
+            chunk_idx: The index of the chunk to compute sizes for.
+        """
+        cu_sizes = self.cu_tokens_across_dp_cpu
+        num_tokens_across_dp_cpu = [
+            (cu_sizes[i] -
+             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
+            for i in range(len(cu_sizes))
+        ]
+        self.local_sizes = _compute_chunked_local_num_tokens(
+            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+        try:
+            yield self.local_sizes
+        finally:
+            self.local_sizes = None
+
+    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+        return self.local_sizes
+
 
 @dataclass
 class ForwardContext:
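A hedged usage sketch of the new context manager, assuming this patch is applied and `DPMetadata` is importable from `vllm.forward_context`; the token counts are made up. Inside the `with` block the per-rank chunk sizes are visible through `get_chunk_sizes_across_dp_rank()`, and they are cleared again on exit.

import torch
from vllm.forward_context import DPMetadata

# Assumed counts: rank 0 has 5 tokens, rank 1 has 13 (cumsum = [5, 18]).
meta = DPMetadata(max_tokens_across_dp_cpu=torch.tensor(13),
                  cu_tokens_across_dp_cpu=torch.tensor([5, 18]))

with meta.chunked_sizes(max_chunk_size_per_rank=8, chunk_idx=1) as sizes:
    assert sizes == [1, 5]
    assert meta.get_chunk_sizes_across_dp_rank() == [1, 5]

# Outside the context the sizes are reset.
assert meta.get_chunk_sizes_across_dp_rank() is None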
@@ -4,7 +4,6 @@ from typing import Any, Optional
 
 import torch
 
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
@@ -14,20 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
 from vllm.utils.flashinfer import nvfp4_block_scale_interleave
 
 
-def get_local_sizes(local_tokens):
-    cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-    sizes = [cu_sizes[0].item()]
-    for i in range(1, len(cu_sizes)):
-        sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
-    max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
-    sizes_chunked = [max_num_tokens] * len(sizes)
-    if local_tokens < max_num_tokens:
-        # When the number of local tokens is less than max_num_tokens, all other
-        # ranks will also have fewer than max_num_tokens. The remaining tokens
-        # are accounted for as residual.
-        sizes_chunked = [x % max_num_tokens for x in sizes]
-
-    return sizes_chunked
+def get_local_sizes():
+    return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
 
 
 class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
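The removed helper re-derived the per-rank sizes from `VLLM_MOE_DP_CHUNK_SIZE` and each rank's own local token count, so different ranks could compute different size lists; the replacement simply reads the sizes that `DPMetadata.chunked_sizes` computed once for the current chunk. A standalone numeric sketch of how the removed logic could disagree across ranks, assuming per-rank totals [5, 13] and a chunk size of 8 (numbers are illustrative only):

# Re-statement of the removed logic: each rank guessed the split from its own
# local token count instead of the shared chunk index.
def old_sizes(per_rank_totals, local_tokens, max_num_tokens):
    sizes = [max_num_tokens] * len(per_rank_totals)
    if local_tokens < max_num_tokens:
        sizes = [x % max_num_tokens for x in per_rank_totals]
    return sizes

totals, chunk = [5, 13], 8
# In chunk 0, rank 0 (5 local tokens) and rank 1 (8 local tokens) disagree:
assert old_sizes(totals, local_tokens=5, max_num_tokens=chunk) == [5, 5]
assert old_sizes(totals, local_tokens=8, max_num_tokens=chunk) == [8, 8]
# The consistent chunk-0 split is [5, 8], which DPMetadata.chunked_sizes(8, 0)
# now hands to every rank through get_local_sizes().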
@@ -90,7 +77,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             topk_weights, topk_ids, a1q, a1q_scale = \
                 get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
                                            dim=0,
-                                           sizes=get_local_sizes(local_tokens))
+                                           sizes=get_local_sizes())
             a1_m, a1_n = a1q.shape
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
@@ -107,8 +94,5 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 ['use_dp', 'local_tokens'])
         if use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
-                fused_expert_output,
-                dim=0,
-                sizes=get_local_sizes(local_tokens),
-            )
+                fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)
@@ -1570,18 +1570,19 @@ class FusedMoE(torch.nn.Module):
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp,
-                                  moe_dp_chunk_size_per_rank):
+        for chunk_idx, chunk_start_ in enumerate(
+                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                             max_tokens_across_dp)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
 
-            process_chunk(chunk_start,
-                          chunk_end,
-                          skip_result_store=chunk_start_ >= num_tokens)
+            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+                                               chunk_idx):
+                process_chunk(chunk_start,
+                              chunk_end,
+                              skip_result_store=chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
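A standalone sketch of the chunking loop above with `process_chunk` stubbed out; it shows how each `chunk_idx` pairs with a local `[chunk_start, chunk_end)` window and how a rank whose 5 tokens run out before `max_tokens_across_dp` (13 here, chunk size 8) keeps iterating with `skip_result_store=True` so collectives stay in lockstep. In the real code each iteration also enters `ctx.dp_metadata.chunked_sizes(chunk_size, chunk_idx)`; the numbers here are assumptions for illustration.

max_tokens_across_dp = 13  # assumed max token count over all DP ranks
chunk_size = 8             # assumed moe_dp_chunk_size_per_rank
num_tokens = 5             # assumed local token count on this rank

schedule = []
for chunk_idx, chunk_start_ in enumerate(
        range(0, max_tokens_across_dp, chunk_size)):
    chunk_start = min(chunk_start_, num_tokens - 1)
    chunk_end = min(min(chunk_start_ + chunk_size, max_tokens_across_dp),
                    num_tokens)
    schedule.append((chunk_idx, chunk_start, chunk_end,
                     chunk_start_ >= num_tokens))  # skip_result_store flag

# Chunk 0 processes local rows [0, 5); chunk 1 re-runs a dummy slice [4, 5)
# purely to keep the collectives in lockstep, discarding its result.
assert schedule == [(0, 0, 5, False), (1, 4, 5, True)]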