diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 4ab8f3d938fc..66d4940c9cec 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -236,7 +236,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
         input_size = input_.size()
         if sizes is not None:
             assert len(sizes) == world_size
-            assert input_.shape[dim] == sizes[self.rank_in_group]
+            assert input_.shape[dim] == sizes[self.rank_in_group], (
+                f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
             output_size = (sum(sizes), ) + input_size[1:]
         else:
             output_size = (input_size[0] * world_size, ) + input_size[1:]
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index dd55b19feeaf..4686ba24e65f 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -26,10 +26,26 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+                                      max_num_tokens: int,
+                                      chunk_idx: int) -> list[int]:
+    dp_size = len(num_tokens_across_dp_cpu)
+
+    local_size = [-1] * dp_size
+    for i in range(dp_size):
+        dp_tokens = num_tokens_across_dp_cpu[i]
+        local_size[i] = min(max_num_tokens,
+                            dp_tokens - (max_num_tokens * chunk_idx))
+        if local_size[i] <= 0:
+            local_size[i] = 1  # ensure lockstep even if done
+    return local_size
+
+
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
+    local_sizes: Optional[list[int]] = None
 
     @staticmethod
     def num_tokens_across_dp(num_tokens: int, dp_size: int,
@@ -78,6 +94,48 @@ class DPMetadata:
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
 
+    @contextmanager
+    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+        """
+        Context manager to compute and temporarily set the per-rank local token
+        sizes for a specific chunk during chunked forward execution.
+
+        This is necessary to ensure each DP (data parallel) rank processes its
+        designated portion of tokens in lockstep with others, even when the
+        token counts are uneven or some ranks have completed their input early.
+
+        For chunked execution, we break up the total tokens on each rank into
+        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
+        `chunk_idx`, this context manager sets `self.local_sizes` to the number
+        of tokens to process in that chunk on each rank.
+
+        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
+        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
+        to determine the chunk-wise split.
+
+        `self.local_sizes` is only valid inside the context.
+
+        Args:
+            max_chunk_size_per_rank: The max number of tokens each rank is
+                allowed to process in this chunk.
+            chunk_idx: The index of the chunk to compute sizes for.
+        """
+        cu_sizes = self.cu_tokens_across_dp_cpu
+        num_tokens_across_dp_cpu = [
+            (cu_sizes[i] -
+             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
+            for i in range(len(cu_sizes))
+        ]
+        self.local_sizes = _compute_chunked_local_num_tokens(
+            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+        try:
+            yield self.local_sizes
+        finally:
+            self.local_sizes = None
+
+    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+        return self.local_sizes
+
 
 @dataclass
 class ForwardContext:
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index 02e1d1f1fd02..7fdb465c459d 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -4,7 +4,6 @@ from typing import Any, Optional
 
 import torch
 
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
@@ -14,20 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import (
 from vllm.utils.flashinfer import nvfp4_block_scale_interleave
 
 
-def get_local_sizes(local_tokens):
-    cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-    sizes = [cu_sizes[0].item()]
-    for i in range(1, len(cu_sizes)):
-        sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
-    max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
-    sizes_chunked = [max_num_tokens] * len(sizes)
-    if local_tokens < max_num_tokens:
-        # When the number of local tokens is less than max_num_tokens, all other
-        # ranks will also have fewer than max_num_tokens. The remaining tokens
-        # are accounted for as residual.
-        sizes_chunked = [x % max_num_tokens for x in sizes]
-
-    return sizes_chunked
+def get_local_sizes():
+    return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
 
 
 class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -90,7 +77,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             topk_weights, topk_ids, a1q, a1q_scale = \
                 get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
                                            dim=0,
-                                           sizes=get_local_sizes(local_tokens))
+                                           sizes=get_local_sizes())
             a1_m, a1_n = a1q.shape
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
@@ -107,8 +94,5 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                                      ['use_dp', 'local_tokens'])
         if use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
-                fused_expert_output,
-                dim=0,
-                sizes=get_local_sizes(local_tokens),
-            )
+                fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 76cedb3ed348..272b6ce67232 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1570,18 +1570,19 @@ class FusedMoE(torch.nn.Module):
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
         num_tokens = full_hidden_states.size(0)
-        for chunk_start_ in range(0, max_tokens_across_dp,
-                                  moe_dp_chunk_size_per_rank):
+        for chunk_idx, chunk_start_ in enumerate(
+                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                             max_tokens_across_dp)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
-
-            process_chunk(chunk_start,
-                          chunk_end,
-                          skip_result_store=chunk_start_ >= num_tokens)
+            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+                                               chunk_idx):
+                process_chunk(chunk_start,
+                              chunk_end,
+                              skip_result_store=chunk_start_ >= num_tokens)
 
         return full_final_hidden_states
 