mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-22 04:54:33 +08:00
low latency dispatch
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
48dcc72d7e
commit
da26dce7b2
@ -293,6 +293,7 @@ def low_latency_dispatch(
|
||||
num_experts,
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
use_fp8=False,
|
||||
)
|
||||
)
|
||||
|
||||
@ -327,6 +328,65 @@ def low_latency_combine(
|
||||
return combined_hidden_states, event_overlap, hook
|
||||
|
||||
|
||||
def run_low_latency():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
|
||||
group = dist.group.WORLD
|
||||
num_experts = 8
|
||||
local_batch_size = 4
|
||||
batch_size = local_batch_size * world_size
|
||||
hidden_size = 2048
|
||||
local_num_experts = num_experts // group.size()
|
||||
x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||
topk = 2
|
||||
num_max_dispatch_tokens_per_rank = local_batch_size
|
||||
low_latency_get_buffer(
|
||||
group, num_max_dispatch_tokens_per_rank, hidden_size, num_experts
|
||||
)
|
||||
|
||||
expert_weights = torch.randn(
|
||||
local_batch_size,
|
||||
num_experts,
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
topk_weights, topk_idx = torch.topk(expert_weights, topk, dim=1)
|
||||
|
||||
recv_hidden_states, recv_expert_count, handle, event, hook = low_latency_dispatch(
|
||||
x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts
|
||||
)
|
||||
hook()
|
||||
|
||||
# Dispatch naive
|
||||
all_x = [torch.empty_like(x) for _ in range(world_size)]
|
||||
all_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)]
|
||||
all_topk_weights = [torch.empty_like(topk_weights) for _ in range(world_size)]
|
||||
|
||||
dist.all_gather(all_x, x)
|
||||
dist.all_gather(all_topk_idx, topk_idx)
|
||||
dist.all_gather(all_topk_weights, topk_weights)
|
||||
|
||||
all_x = torch.cat(all_x, dim=0)
|
||||
all_topk_idx = torch.cat(all_topk_idx, dim=0)
|
||||
all_topk_weights = torch.cat(all_topk_weights, dim=0)
|
||||
|
||||
expert_tok_ids = [[] for _ in range(local_num_experts)]
|
||||
|
||||
expert_range_start = rank * local_num_experts
|
||||
for i, cnt in enumerate(recv_expert_count):
|
||||
for j in range(cnt):
|
||||
for k in range(batch_size):
|
||||
if torch.allclose(
|
||||
recv_hidden_states[i, j], all_x[k], rtol=0.0, atol=0.0
|
||||
):
|
||||
expert_tok_ids[i].append(k)
|
||||
assert i + expert_range_start in all_topk_idx[k]
|
||||
break
|
||||
|
||||
assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
@ -335,3 +395,4 @@ if __name__ == "__main__":
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
run_high_throughput()
|
||||
run_low_latency()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user