diff --git a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
index 28a391da99789..dcac696407422 100644
--- a/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py
@@ -120,6 +120,25 @@ def dispatch_forward(
     )
 
 
+def combine_forward(
+    x: torch.Tensor, handle: tuple, previous_event: EventOverlap | None = None
+) -> tuple[torch.Tensor, EventOverlap]:
+    global _buffer
+
+    # Do MoE combine
+    # For more advanced usages, please refer to the docs of the `combine` function
+    combined_x, _, event = _buffer.combine(
+        x,
+        handle,
+        async_finish=True,
+        previous_event=previous_event,
+        allocate_on_comm_stream=previous_event is not None,
+    )
+
+    # For event management, please refer to the docs of the `EventOverlap` class
+    return combined_x, event
+
+
 if __name__ == "__main__":
     torch.distributed.init_process_group(
         backend="nccl",
@@ -134,7 +153,7 @@ if __name__ == "__main__":
     num_experts = 8
     local_batch_size = 4
     batch_size = local_batch_size * world_size
-    hidden_size = 128
+    hidden_size = 32
     local_num_experts = num_experts // group.size()
     x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
     hidden_bytes = get_hidden_bytes(x)
@@ -163,7 +182,6 @@ if __name__ == "__main__":
         topk_weights,
         num_experts,
     )
-    # print(f"rank {rank} recv_x: {recv_x.shape=}")
     recv_topk_global_idx = recv_topk_idx + torch.where(
         recv_topk_idx == -1,
         0,
@@ -198,9 +216,7 @@ if __name__ == "__main__":
                 for prev in range(j):
                     # Assert previous tokens are not activated on this rank
                     assert recv_topk_idx[recv_i, prev] == -1
-                    assert recv_topk_weights[recv_i, prev] == 0.0, (
-                        f"{recv_topk_weights[recv_i, prev]=}"
-                    )
+                    assert recv_topk_weights[recv_i, prev] == 0.0
                 expert_inputs_ground_truth.append(all_x[i])
 
     assert (
@@ -214,3 +230,22 @@ if __name__ == "__main__":
 
     expert_inputs_ground_truth = torch.stack(expert_inputs_ground_truth, dim=0)
     torch.testing.assert_close(expert_inputs_ground_truth, recv_x, atol=0.0, rtol=0.0)
+
+    # NOTE: This dispatch result cannot be directly fed into the GEMM kernel; it
+    # needs another shuffling/expansion process to cluster the input to each expert.
+    # The results to combine will be reduced on the local rank first.
+
+    # Combine
+    results = recv_x
+    combined_results, event = combine_forward(results, handle, event)
+
+    for i in range(local_batch_size):
+        appear_count = 0.0
+        for j in range(world_size):
+            expert_range_start = j * local_num_experts
+            expert_range_end = (j + 1) * local_num_experts
+            if any(
+                (expert_range_start <= topk_idx[i]) & (topk_idx[i] < expert_range_end)
+            ):
+                appear_count += 1.0
+        torch.testing.assert_close(combined_results[i], x[i] * appear_count)
diff --git a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
index d73daaa846a61..b9e438d6bd12a 100644
--- a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
@@ -136,7 +136,7 @@ def run_batched_deepgemm_contiguous_bf16(
         aligned_end = start + aligned_ms[i]
         expert_ids[start:actual_end] = i
         expert_ids[actual_end:aligned_end] = -1
-        reference_output[start:aligned_end] = x[start:aligned_end] @ weight[i].t()
+        reference_output[start:actual_end] = x[start:actual_end] @ weight[i].t()
         start = aligned_end
 
     output = torch.zeros(
@@ -152,6 +152,7 @@ def run_batched_deepgemm_contiguous_bf16(
         output,
         expert_ids,
     )
+    output = output * (expert_ids != -1).unsqueeze(1)
 
     torch.testing.assert_close(output, reference_output)
 
@@ -351,6 +352,6 @@ def run_triton_group_gemm_masked_bf16(
 # run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
 run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
 
-run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
-run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
-run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
+# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
+# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
+# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
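
The check added at the end of the EP script relies on a simple invariant: with the identity "expert computation" used here, `combine` sums a token's copies across every rank whose local expert range intersects the token's top-k experts, so each combined row equals `x[i]` scaled by that rank count. The following is a minimal CPU-only sketch of that counting logic; it is not part of the patch, and the shapes and expert assignments are made up for illustration.

```python
# Sketch only (not from the patch): counts, per token, how many ranks own at
# least one of its top-k experts. The combine check above asserts that the
# combined row equals x[i] * this count. Values here are invented.
import torch

world_size, local_num_experts = 4, 2          # 8 global experts, 2 per rank
topk_idx = torch.tensor([[0, 1, 5],           # token 0 -> experts owned by ranks {0, 2}
                         [0, 4, 6]])          # token 1 -> experts owned by ranks {0, 2, 3}

ranks_hit = torch.zeros(topk_idx.shape[0])
for i in range(topk_idx.shape[0]):
    for j in range(world_size):
        lo, hi = j * local_num_experts, (j + 1) * local_num_experts
        if bool(((lo <= topk_idx[i]) & (topk_idx[i] < hi)).any()):
            ranks_hit[i] += 1

print(ranks_hit)  # tensor([2., 3.])
```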
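Similarly, the two grouped-GEMM changes concern rows that exist only to pad each expert's block up to the alignment boundary: the reference is now computed only over real rows, and the kernel output is zeroed where `expert_ids == -1` before comparison. Below is a standalone sketch of that masking idea under assumed shapes; names and sizes are invented and the plain matmul stands in for the grouped GEMM kernel.

```python
# Sketch only (not from the patch): one expert with 5 real rows padded to an
# aligned block of 8. Padded rows carry expert_id == -1; the reference skips
# them and the (possibly garbage) kernel output is masked to zero there.
import torch

m, k, n, aligned_m = 5, 16, 8, 8
x = torch.randn(aligned_m, k)
w = torch.randn(n, k)

expert_ids = torch.zeros(aligned_m, dtype=torch.long)
expert_ids[m:] = -1                                  # padding rows

reference = torch.zeros(aligned_m, n)
reference[:m] = x[:m] @ w.t()                        # only real rows contribute

output = x @ w.t()                                   # stand-in for the grouped GEMM kernel
output = output * (expert_ids != -1).unsqueeze(1)    # zero out padded rows

torch.testing.assert_close(output, reference)
```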