mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:27:19 +08:00
refactor and add low latency code
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
e3e2bb3865
commit
48dcc72d7e
@ -139,15 +139,9 @@ def combine_forward(
|
||||
return combined_x, event
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
init_method="env://",
|
||||
)
|
||||
def run_high_throughput():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
group = dist.group.WORLD
|
||||
num_experts = 8
|
||||
@ -182,11 +176,6 @@ if __name__ == "__main__":
|
||||
topk_weights,
|
||||
num_experts,
|
||||
)
|
||||
recv_topk_global_idx = recv_topk_idx + torch.where(
|
||||
recv_topk_idx == -1,
|
||||
0,
|
||||
rank * local_num_experts,
|
||||
)
|
||||
|
||||
# Dispatch naive
|
||||
all_x = [torch.empty_like(x) for _ in range(world_size)]
|
||||
@ -249,3 +238,100 @@ if __name__ == "__main__":
|
||||
):
|
||||
appear_count += 1.0
|
||||
torch.testing.assert_close(combined_results[i], x[i] * appear_count)
|
||||
|
||||
|
||||
# You may call this function once during framework initialization
|
||||
def low_latency_get_buffer(
    group: dist.ProcessGroup,
    num_max_dispatch_tokens_per_rank: int,
    hidden: int,
    num_experts: int,
) -> Buffer:
    """Return the shared low-latency ``Buffer``, allocating a new one if needed.

    The module-level ``_buffer`` is reused when it already exists, belongs to
    ``group``, is in low-latency mode, and has at least as many RDMA bytes as
    ``Buffer.get_low_latency_rdma_size_hint`` reports for the given sizes.
    Otherwise a new ``Buffer`` is constructed and stored in ``_buffer``.

    NOTE: low-latency mode consumes much more space than the normal mode, so
    ``num_max_dispatch_tokens_per_rank`` (the actual batch size in the
    decoding engine) is recommended to be less than 256.

    Args:
        group: Process group the buffer communicates over.
        num_max_dispatch_tokens_per_rank: Upper bound on tokens dispatched
            per rank; forwarded to the size hint.
        hidden: Hidden size; forwarded to the size hint.
        num_experts: Total number of experts; must be divisible by
            ``group.size()``.

    Returns:
        The (possibly newly created) module-level buffer.

    Raises:
        AssertionError: If a new buffer is required and ``num_experts`` is
            not a multiple of ``group.size()``.
    """
    global _buffer

    world_size = group.size()
    required_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank, hidden, world_size, num_experts
    )

    # A fresh buffer is needed when none exists yet, or when the cached one
    # was built for another group, for the normal mode, or is too small.
    must_allocate = (
        _buffer is None
        or _buffer.group != group
        or not _buffer.low_latency_mode
        or _buffer.num_rdma_bytes < required_rdma_bytes
    )
    if not must_allocate:
        return _buffer

    # NOTE: for the best performance, the QP number **must** be equal to the
    # number of local experts, hence the divisibility requirement.
    assert num_experts % world_size == 0
    experts_per_rank = num_experts // world_size

    # The second positional argument is passed as 0 (presumably the non-RDMA
    # byte budget -- confirm against the Buffer constructor signature).
    _buffer = Buffer(
        group,
        0,
        required_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=experts_per_rank,
    )
    return _buffer
|
||||
|
||||
|
||||
def low_latency_dispatch(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    num_max_dispatch_tokens_per_rank: int,
    num_experts: int,
):
    """Run the low-latency MoE dispatch on the module-level ``_buffer``.

    Forwards the arguments to ``_buffer.low_latency_dispatch`` with
    ``async_finish=False`` and ``return_recv_hook=True``.

    The call is compatible with CUDA graphs (but you may need to restore
    some buffer status once you replay).

    NOTE: because a receive hook is requested, the actual tensor is only
    received once ``hook()`` is called. This is useful for double-batch
    overlapping **without any SM occupation**. If you do not want to
    overlap, set ``return_recv_hook=False`` instead. The received hidden
    states come in a specific format intended for a matching GEMM library.

    Args:
        hidden_states: Tokens to dispatch.
        topk_idx: Selected expert indices per token.
        num_max_dispatch_tokens_per_rank: Upper bound on tokens dispatched
            per rank.
        num_experts: Total number of experts.

    Returns:
        The tuple ``(recv_hidden_states, recv_expert_count, handle, event,
        hook)`` produced by the buffer.
    """
    # `_buffer` is only read here; it is expected to have been created
    # beforehand (e.g. through `low_latency_get_buffer`).
    dispatch_outputs = _buffer.low_latency_dispatch(
        hidden_states,
        topk_idx,
        num_max_dispatch_tokens_per_rank,
        num_experts,
        async_finish=False,
        return_recv_hook=True,
    )

    # Unpack explicitly so the five-element result shape is enforced.
    recv_hidden_states, recv_expert_count, handle, event, hook = dispatch_outputs
    return recv_hidden_states, recv_expert_count, handle, event, hook
|
||||
|
||||
|
||||
def low_latency_combine(
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    handle: tuple,
):
    """Run the low-latency MoE combine on the module-level ``_buffer``.

    Forwards the arguments to ``_buffer.low_latency_combine`` with
    ``async_finish=False`` and ``return_recv_hook=True``.

    The call is compatible with CUDA graphs (but you may need to restore
    some buffer status once you replay). As with the dispatch kernel, the
    result is only actually received once the returned ``hook()`` is called.

    Args:
        hidden_states: Expert outputs to combine.
        topk_idx: Selected expert indices per token.
        topk_weights: Weights associated with ``topk_idx``.
        handle: Communication handle returned by the matching dispatch call.

    Returns:
        The tuple ``(combined_hidden_states, event_overlap, hook)`` produced
        by the buffer.
    """
    # `_buffer` is only read here; it is expected to have been created
    # beforehand (e.g. through `low_latency_get_buffer`).
    combine_outputs = _buffer.low_latency_combine(
        hidden_states,
        topk_idx,
        topk_weights,
        handle,
        async_finish=False,
        return_recv_hook=True,
    )

    # Unpack explicitly so the three-element result shape is enforced.
    combined_hidden_states, event_overlap, hook = combine_outputs
    return combined_hidden_states, event_overlap, hook
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: join the default process group using the
    # environment-variable rendezvous ("env://") with the NCCL backend.
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # Bind this process to the GPU named by LOCAL_RANK (defaults to 0 when
    # the variable is not set).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Only the high-throughput path is exercised from the command line.
    run_high_throughput()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user