mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 05:05:46 +08:00
add combine example
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
4e2abe99b7
commit
e3e2bb3865
@ -120,6 +120,25 @@ def dispatch_forward(
|
||||
)
|
||||
|
||||
|
||||
def combine_forward(
    x: torch.Tensor, handle: tuple, previous_event: EventOverlap | None = None
) -> tuple[torch.Tensor, EventOverlap]:
    """Run the MoE combine step on the module-level communication buffer.

    Args:
        x: Tensor to combine; forwarded unchanged to ``_buffer.combine``.
        handle: Tuple forwarded unchanged to ``_buffer.combine``
            (presumably the handle produced by the matching dispatch call;
            confirm against the caller).
        previous_event: Optional event forwarded to ``_buffer.combine``.
            When it is provided, ``allocate_on_comm_stream`` is also enabled.

    Returns:
        A ``(combined_x, event)`` pair: the first and third items of the
        3-tuple returned by ``_buffer.combine``. The second item is discarded.
    """
    # Allocation on the communication stream is requested only when the
    # caller hands in an event to chain after.
    chain_after_previous = previous_event is not None

    # Do the MoE combine. For more advanced usages, please refer to the docs
    # of the `combine` function.
    combined, _unused, done_event = _buffer.combine(
        x,
        handle,
        async_finish=True,
        previous_event=previous_event,
        allocate_on_comm_stream=chain_after_previous,
    )

    # For event management, please refer to the docs of the `EventOverlap`
    # class.
    return combined, done_event
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
@ -134,7 +153,7 @@ if __name__ == "__main__":
|
||||
num_experts = 8
|
||||
local_batch_size = 4
|
||||
batch_size = local_batch_size * world_size
|
||||
hidden_size = 128
|
||||
hidden_size = 32
|
||||
local_num_experts = num_experts // group.size()
|
||||
x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||
hidden_bytes = get_hidden_bytes(x)
|
||||
@ -163,7 +182,6 @@ if __name__ == "__main__":
|
||||
topk_weights,
|
||||
num_experts,
|
||||
)
|
||||
# print(f"rank {rank} recv_x: {recv_x.shape=}")
|
||||
recv_topk_global_idx = recv_topk_idx + torch.where(
|
||||
recv_topk_idx == -1,
|
||||
0,
|
||||
@ -198,9 +216,7 @@ if __name__ == "__main__":
|
||||
for prev in range(j):
|
||||
# Assert previous tokens are not activated on this rank
|
||||
assert recv_topk_idx[recv_i, prev] == -1
|
||||
assert recv_topk_weights[recv_i, prev] == 0.0, (
|
||||
f"{recv_topk_weights[recv_i, prev]=}"
|
||||
)
|
||||
assert recv_topk_weights[recv_i, prev] == 0.0
|
||||
expert_inputs_ground_truth.append(all_x[i])
|
||||
|
||||
assert (
|
||||
@ -214,3 +230,22 @@ if __name__ == "__main__":
|
||||
|
||||
expert_inputs_ground_truth = torch.stack(expert_inputs_ground_truth, dim=0)
|
||||
torch.testing.assert_close(expert_inputs_ground_truth, recv_x, atol=0.0, rtol=0.0)
|
||||
|
||||
# NOTE: This dispatch result cannot be directly fed into the GEMM kernel, it
|
||||
# needs another shuffling/expansion process to cluster the input to each expert.
|
||||
# The results to combine will be reduced on the local rank first.
|
||||
|
||||
# Combine
|
||||
results = recv_x
|
||||
combined_results, event = combine_forward(results, handle, event)
|
||||
|
||||
for i in range(local_batch_size):
|
||||
appear_count = 0.0
|
||||
for j in range(world_size):
|
||||
expert_range_start = j * local_num_experts
|
||||
expert_range_end = (j + 1) * local_num_experts
|
||||
if any(
|
||||
(expert_range_start <= topk_idx[i]) & (topk_idx[i] < expert_range_end)
|
||||
):
|
||||
appear_count += 1.0
|
||||
torch.testing.assert_close(combined_results[i], x[i] * appear_count)
|
||||
|
||||
@ -136,7 +136,7 @@ def run_batched_deepgemm_contiguous_bf16(
|
||||
aligned_end = start + aligned_ms[i]
|
||||
expert_ids[start:actual_end] = i
|
||||
expert_ids[actual_end:aligned_end] = -1
|
||||
reference_output[start:aligned_end] = x[start:aligned_end] @ weight[i].t()
|
||||
reference_output[start:actual_end] = x[start:actual_end] @ weight[i].t()
|
||||
start = aligned_end
|
||||
|
||||
output = torch.zeros(
|
||||
@ -152,6 +152,7 @@ def run_batched_deepgemm_contiguous_bf16(
|
||||
output,
|
||||
expert_ids,
|
||||
)
|
||||
output = output * (expert_ids != -1).unsqueeze(1)
|
||||
torch.testing.assert_close(output, reference_output)
|
||||
|
||||
|
||||
@ -351,6 +352,6 @@ def run_triton_group_gemm_masked_bf16(
|
||||
|
||||
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
|
||||
run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
|
||||
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
|
||||
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
|
||||
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
|
||||
# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
|
||||
# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
|
||||
# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user