mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:47:09 +08:00
low latency combine
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
da26dce7b2
commit
99e2379b16
@ -371,8 +371,8 @@ def run_low_latency():
|
|||||||
all_topk_idx = torch.cat(all_topk_idx, dim=0)
|
all_topk_idx = torch.cat(all_topk_idx, dim=0)
|
||||||
all_topk_weights = torch.cat(all_topk_weights, dim=0)
|
all_topk_weights = torch.cat(all_topk_weights, dim=0)
|
||||||
|
|
||||||
|
# Verification
|
||||||
expert_tok_ids = [[] for _ in range(local_num_experts)]
|
expert_tok_ids = [[] for _ in range(local_num_experts)]
|
||||||
|
|
||||||
expert_range_start = rank * local_num_experts
|
expert_range_start = rank * local_num_experts
|
||||||
for i, cnt in enumerate(recv_expert_count):
|
for i, cnt in enumerate(recv_expert_count):
|
||||||
for j in range(cnt):
|
for j in range(cnt):
|
||||||
@ -386,6 +386,19 @@ def run_low_latency():
|
|||||||
|
|
||||||
assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
|
assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
|
||||||
|
|
||||||
|
# Combine
|
||||||
|
combined_hidden_states, event_overlap, hook = low_latency_combine(
|
||||||
|
recv_hidden_states, topk_idx, topk_weights, handle
|
||||||
|
)
|
||||||
|
hook()
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
combined_hidden_states.to(torch.float32),
|
||||||
|
x * topk_weights.sum(dim=-1).unsqueeze(-1),
|
||||||
|
atol=1e-2,
|
||||||
|
rtol=1e-2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user