Fix Flashinfer CUTLASS MOE Allgather (#21963)
Signed-off-by: Shu Wang <shuw@nvidia.com>
parent a3b9c17b56
commit b2c8ce57c6
@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
         input_size = input_.size()
         if sizes is not None:
             assert len(sizes) == world_size
-            assert input_.shape[dim] == sizes[self.rank_in_group]
+            assert input_.shape[dim] == sizes[self.rank_in_group], (
+                f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
             output_size = (sum(sizes), ) + input_size[1:]
         else:
             output_size = (input_size[0] * world_size, ) + input_size[1:]
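
Note on the assertion above (not part of the diff): for an uneven all-gather, each rank's local shard must match the size it advertises in `sizes`, and the new message reports both numbers when they disagree. A minimal standalone sketch of that shape bookkeeping, using made-up names (`check_uneven_allgather_shapes`, `local_shard`, `rank`) rather than vLLM's communicator API:

```python
import torch

def check_uneven_allgather_shapes(local_shard: torch.Tensor,
                                  sizes: list[int],
                                  rank: int,
                                  dim: int = 0) -> tuple[int, ...]:
    # Each rank contributes exactly sizes[rank] rows along `dim`; otherwise the
    # concatenated output of sum(sizes) rows would be misaligned across ranks.
    assert local_shard.shape[dim] == sizes[rank], (
        f"{local_shard.shape[dim]} != {sizes[rank]}")
    return (sum(sizes), ) + tuple(local_shard.size())[1:]

# Three ranks with uneven token counts 5, 3, 7; rank 1 holds a (3, 8) shard.
print(check_uneven_allgather_shapes(torch.zeros(3, 8), [5, 3, 7], rank=1))
# -> (15, 8)
```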
@@ -26,10 +26,26 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+                                      max_num_tokens: int,
+                                      chunk_idx: int) -> list[int]:
+    dp_size = len(num_tokens_across_dp_cpu)
+
+    local_size = [-1] * dp_size
+    for i in range(dp_size):
+        dp_tokens = num_tokens_across_dp_cpu[i]
+        local_size[i] = min(max_num_tokens,
+                            dp_tokens - (max_num_tokens * chunk_idx))
+        if local_size[i] <= 0:
+            local_size[i] = 1  # ensure lockstep even if done
+    return local_size
+
+
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
+    local_sizes: Optional[list[int]] = None
 
     @staticmethod
     def num_tokens_across_dp(num_tokens: int, dp_size: int,
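
A quick worked example of the chunk-size computation introduced above (a standalone copy of the same arithmetic, with illustrative token counts):

```python
def chunked_local_num_tokens(tokens_per_rank: list[int],
                             max_num_tokens: int,
                             chunk_idx: int) -> list[int]:
    # Same arithmetic as _compute_chunked_local_num_tokens: each rank takes at
    # most max_num_tokens per chunk, and a rank that has run out still reports
    # one token so every rank keeps issuing the same collectives.
    return [max(min(max_num_tokens, t - max_num_tokens * chunk_idx), 1)
            for t in tokens_per_rank]

tokens_per_rank = [5, 13]   # uneven token counts across two DP ranks
print(chunked_local_num_tokens(tokens_per_rank, 8, 0))  # [5, 8]
print(chunked_local_num_tokens(tokens_per_rank, 8, 1))  # [1, 5]
```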
@@ -78,6 +94,48 @@ class DPMetadata:
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
 
+    @contextmanager
+    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+        """
+        Context manager to compute and temporarily set the per-rank local token
+        sizes for a specific chunk during chunked forward execution.
+
+        This is necessary to ensure each DP (data parallel) rank processes its
+        designated portion of tokens in lockstep with others, even when the
+        token counts are uneven or some ranks have completed their input early.
+
+        For chunked execution, we break up the total tokens on each rank into
+        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
+        `chunk_idx`, this context manager sets `self.local_sizes` to the number
+        of tokens to process in that chunk on each rank.
+
+        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
+        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
+        to determine the chunk-wise split.
+
+        `self.local_sizes` is only valid inside the context.
+
+        Args:
+            max_chunk_size_per_rank: The max number of tokens each rank is
+                allowed to process in this chunk.
+            chunk_idx: The index of the chunk to compute sizes for.
+        """
+        cu_sizes = self.cu_tokens_across_dp_cpu
+        num_tokens_across_dp_cpu = [
+            (cu_sizes[i] -
+             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
+            for i in range(len(cu_sizes))
+        ]
+        self.local_sizes = _compute_chunked_local_num_tokens(
+            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+        try:
+            yield self.local_sizes
+        finally:
+            self.local_sizes = None
+
+    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+        return self.local_sizes
+
 
 @dataclass
 class ForwardContext:
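
For orientation, a minimal standalone sketch of how `chunked_sizes` recovers per-rank token counts from the cumulative tensor before splitting them into chunks (the numbers are illustrative):

```python
import torch

# cu_tokens_across_dp_cpu is a running sum of per-rank token counts.
cu_tokens_across_dp_cpu = torch.cumsum(torch.tensor([5, 13, 2]), dim=0)  # tensor([ 5, 18, 20])

# Differencing the cumulative sums gives back the per-rank counts, exactly as
# the list comprehension inside chunked_sizes does.
per_rank = [
    (cu_tokens_across_dp_cpu[i] - cu_tokens_across_dp_cpu[i - 1]).item()
    if i > 0 else cu_tokens_across_dp_cpu[0].item()
    for i in range(len(cu_tokens_across_dp_cpu))
]
print(per_rank)  # [5, 13, 2]
```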
@@ -4,7 +4,6 @@ from typing import Any, Optional
 
 import torch
 
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
@@ -14,20 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
 from vllm.utils.flashinfer import nvfp4_block_scale_interleave
 
 
-def get_local_sizes(local_tokens):
-    cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-    sizes = [cu_sizes[0].item()]
-    for i in range(1, len(cu_sizes)):
-        sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
-    max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
-    sizes_chunked = [max_num_tokens] * len(sizes)
-    if local_tokens < max_num_tokens:
-        # When the number of local tokens is less than max_num_tokens, all other
-        # ranks will also have fewer than max_num_tokens. The remaining tokens
-        # are accounted for as residual.
-        sizes_chunked = [x % max_num_tokens for x in sizes]
-
-    return sizes_chunked
+def get_local_sizes():
+    return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
 
 
 class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
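
One way to see what this rewrite avoids is to run the old modulo heuristic on uneven counts (illustrative numbers, not taken from a real run):

```python
# Per-rank totals [5, 13] with a chunk size of 8 (VLLM_MOE_DP_CHUNK_SIZE).
sizes = [5, 13]
max_num_tokens = 8

# Old path when local_tokens < max_num_tokens: a single modulo-based answer.
old_chunked = [x % max_num_tokens for x in sizes]
print(old_chunked)  # [5, 5]

# But per the chunking above, rank 1 contributes min(8, 13) = 8 tokens in
# chunk 0 and the remaining 5 in chunk 1, so no single list can describe both
# chunks. The new get_local_sizes() instead reads the per-chunk split that
# DPMetadata.chunked_sizes() computed for the current chunk.
```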
@@ -90,7 +77,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             topk_weights, topk_ids, a1q, a1q_scale = \
                 get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
                                            dim=0,
-                                           sizes=get_local_sizes(local_tokens))
+                                           sizes=get_local_sizes())
             a1_m, a1_n = a1q.shape
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
@@ -107,8 +94,5 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                              ['use_dp', 'local_tokens'])
         if use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
-                fused_expert_output,
-                dim=0,
-                sizes=get_local_sizes(local_tokens),
-            )
+                fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)
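
The gather and the scatter now consume the same per-chunk sizes, which keeps the row bookkeeping symmetric: the all-gather concatenates `sum(sizes)` rows and the reduce-scatter returns to each rank its own `sizes[rank]` rows (assuming the usual v-variant collective semantics). A tiny numeric illustration:

```python
sizes = [5, 8, 3]                 # per-rank rows for the current chunk
gathered_rows = sum(sizes)        # 16 rows go through the fused experts
offsets = [sum(sizes[:i]) for i in range(len(sizes))]
print(gathered_rows, offsets)     # 16 [0, 5, 13]
# Rank 1's slice of the gathered tensor is rows [5:13), i.e. sizes[1] == 8
# rows, which is exactly what reduce_scatterv hands back to it.
```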
@@ -1570,18 +1570,19 @@ class FusedMoE(torch.nn.Module):
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp,
-                                  moe_dp_chunk_size_per_rank):
+        for chunk_idx, chunk_start_ in enumerate(
+                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                             max_tokens_across_dp)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
-
-            process_chunk(chunk_start,
-                          chunk_end,
-                          skip_result_store=chunk_start_ >= num_tokens)
+            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+                                               chunk_idx):
+                process_chunk(chunk_start,
+                              chunk_end,
+                              skip_result_store=chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
 
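
A toy, self-contained sketch of the loop/context-manager interaction introduced above (`ToyMetadata` and the token counts are illustrative, not vLLM code):

```python
from contextlib import contextmanager

class ToyMetadata:
    """Mimics DPMetadata just enough to show the scoping behaviour."""

    def __init__(self, tokens_per_rank: list[int]):
        self.tokens_per_rank = tokens_per_rank
        self.local_sizes = None

    @contextmanager
    def chunked_sizes(self, max_chunk: int, chunk_idx: int):
        # Per-rank sizes for this chunk, clamped to at least 1 for lockstep.
        self.local_sizes = [max(min(max_chunk, t - max_chunk * chunk_idx), 1)
                            for t in self.tokens_per_rank]
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

meta = ToyMetadata([5, 13])
max_tokens_across_dp, chunk_size = 13, 8
for chunk_idx, chunk_start in enumerate(
        range(0, max_tokens_across_dp, chunk_size)):
    with meta.chunked_sizes(chunk_size, chunk_idx) as sizes:
        print(chunk_idx, chunk_start, sizes)  # 0 0 [5, 8] then 1 8 [1, 5]
    assert meta.local_sizes is None           # only valid inside the context
```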