mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 23:17:12 +08:00
low latency dispatch
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
48dcc72d7e
commit
da26dce7b2
@ -293,6 +293,7 @@ def low_latency_dispatch(
|
|||||||
num_experts,
|
num_experts,
|
||||||
async_finish=False,
|
async_finish=False,
|
||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
|
use_fp8=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -327,6 +328,65 @@ def low_latency_combine(
|
|||||||
return combined_hidden_states, event_overlap, hook
|
return combined_hidden_states, event_overlap, hook
|
||||||
|
|
||||||
|
|
||||||
|
def run_low_latency():
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
rank = int(os.environ.get("RANK", 0))
|
||||||
|
|
||||||
|
group = dist.group.WORLD
|
||||||
|
num_experts = 8
|
||||||
|
local_batch_size = 4
|
||||||
|
batch_size = local_batch_size * world_size
|
||||||
|
hidden_size = 2048
|
||||||
|
local_num_experts = num_experts // group.size()
|
||||||
|
x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
|
||||||
|
topk = 2
|
||||||
|
num_max_dispatch_tokens_per_rank = local_batch_size
|
||||||
|
low_latency_get_buffer(
|
||||||
|
group, num_max_dispatch_tokens_per_rank, hidden_size, num_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
expert_weights = torch.randn(
|
||||||
|
local_batch_size,
|
||||||
|
num_experts,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
topk_weights, topk_idx = torch.topk(expert_weights, topk, dim=1)
|
||||||
|
|
||||||
|
recv_hidden_states, recv_expert_count, handle, event, hook = low_latency_dispatch(
|
||||||
|
x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts
|
||||||
|
)
|
||||||
|
hook()
|
||||||
|
|
||||||
|
# Dispatch naive
|
||||||
|
all_x = [torch.empty_like(x) for _ in range(world_size)]
|
||||||
|
all_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)]
|
||||||
|
all_topk_weights = [torch.empty_like(topk_weights) for _ in range(world_size)]
|
||||||
|
|
||||||
|
dist.all_gather(all_x, x)
|
||||||
|
dist.all_gather(all_topk_idx, topk_idx)
|
||||||
|
dist.all_gather(all_topk_weights, topk_weights)
|
||||||
|
|
||||||
|
all_x = torch.cat(all_x, dim=0)
|
||||||
|
all_topk_idx = torch.cat(all_topk_idx, dim=0)
|
||||||
|
all_topk_weights = torch.cat(all_topk_weights, dim=0)
|
||||||
|
|
||||||
|
expert_tok_ids = [[] for _ in range(local_num_experts)]
|
||||||
|
|
||||||
|
expert_range_start = rank * local_num_experts
|
||||||
|
for i, cnt in enumerate(recv_expert_count):
|
||||||
|
for j in range(cnt):
|
||||||
|
for k in range(batch_size):
|
||||||
|
if torch.allclose(
|
||||||
|
recv_hidden_states[i, j], all_x[k], rtol=0.0, atol=0.0
|
||||||
|
):
|
||||||
|
expert_tok_ids[i].append(k)
|
||||||
|
assert i + expert_range_start in all_topk_idx[k]
|
||||||
|
break
|
||||||
|
|
||||||
|
assert [len(tok_ids) for tok_ids in expert_tok_ids] == recv_expert_count.tolist()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend="nccl",
|
backend="nccl",
|
||||||
@ -335,3 +395,4 @@ if __name__ == "__main__":
|
|||||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
torch.cuda.set_device(local_rank)
|
torch.cuda.set_device(local_rank)
|
||||||
run_high_throughput()
|
run_high_throughput()
|
||||||
|
run_low_latency()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user