mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 14:27:02 +08:00
refactor and add low latency code
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
e3e2bb3865
commit
48dcc72d7e
@ -139,15 +139,9 @@ def combine_forward(
|
|||||||
return combined_x, event
|
return combined_x, event
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def run_high_throughput():
|
||||||
torch.distributed.init_process_group(
|
|
||||||
backend="nccl",
|
|
||||||
init_method="env://",
|
|
||||||
)
|
|
||||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
rank = int(os.environ.get("RANK", 0))
|
rank = int(os.environ.get("RANK", 0))
|
||||||
torch.cuda.set_device(local_rank)
|
|
||||||
|
|
||||||
group = dist.group.WORLD
|
group = dist.group.WORLD
|
||||||
num_experts = 8
|
num_experts = 8
|
||||||
@ -182,11 +176,6 @@ if __name__ == "__main__":
|
|||||||
topk_weights,
|
topk_weights,
|
||||||
num_experts,
|
num_experts,
|
||||||
)
|
)
|
||||||
recv_topk_global_idx = recv_topk_idx + torch.where(
|
|
||||||
recv_topk_idx == -1,
|
|
||||||
0,
|
|
||||||
rank * local_num_experts,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dispatch naive
|
# Dispatch naive
|
||||||
all_x = [torch.empty_like(x) for _ in range(world_size)]
|
all_x = [torch.empty_like(x) for _ in range(world_size)]
|
||||||
@ -249,3 +238,100 @@ if __name__ == "__main__":
|
|||||||
):
|
):
|
||||||
appear_count += 1.0
|
appear_count += 1.0
|
||||||
torch.testing.assert_close(combined_results[i], x[i] * appear_count)
|
torch.testing.assert_close(combined_results[i], x[i] * appear_count)
|
||||||
|
|
||||||
|
|
||||||
|
# You may call this function at the framework initialization
|
||||||
|
def low_latency_get_buffer(
|
||||||
|
group: dist.ProcessGroup,
|
||||||
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
|
hidden: int,
|
||||||
|
num_experts: int,
|
||||||
|
) -> Buffer:
|
||||||
|
# NOTES: the low-latency mode will consume much more space than the normal mode
|
||||||
|
# So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size
|
||||||
|
# in the decoding engine) should be less than 256
|
||||||
|
global _buffer
|
||||||
|
num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
|
||||||
|
num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allocate a buffer if not existed or not enough buffer size
|
||||||
|
if (
|
||||||
|
_buffer is None
|
||||||
|
or _buffer.group != group
|
||||||
|
or not _buffer.low_latency_mode
|
||||||
|
or _buffer.num_rdma_bytes < num_rdma_bytes
|
||||||
|
):
|
||||||
|
# NOTES: for the best performance, the QP number **must** be equal to the
|
||||||
|
# number of the local experts
|
||||||
|
assert num_experts % group.size() == 0
|
||||||
|
_buffer = Buffer(
|
||||||
|
group,
|
||||||
|
0,
|
||||||
|
num_rdma_bytes,
|
||||||
|
low_latency_mode=True,
|
||||||
|
num_qps_per_rank=num_experts // group.size(),
|
||||||
|
)
|
||||||
|
return _buffer
|
||||||
|
|
||||||
|
|
||||||
|
def low_latency_dispatch(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
|
num_experts: int,
|
||||||
|
):
|
||||||
|
global _buffer
|
||||||
|
|
||||||
|
# Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer
|
||||||
|
# status once you replay)
|
||||||
|
recv_hidden_states, recv_expert_count, handle, event, hook = (
|
||||||
|
_buffer.low_latency_dispatch(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
num_max_dispatch_tokens_per_rank,
|
||||||
|
num_experts,
|
||||||
|
async_finish=False,
|
||||||
|
return_recv_hook=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTES: the actual tensor will not be received only if you call `hook()`,
|
||||||
|
# it is useful for double-batch overlapping, but **without any SM occupation**
|
||||||
|
# If you don't want to overlap, please set `return_recv_hook=False`
|
||||||
|
# Later, you can use our GEMM library to do the computation with this specific
|
||||||
|
# format
|
||||||
|
return recv_hidden_states, recv_expert_count, handle, event, hook
|
||||||
|
|
||||||
|
|
||||||
|
def low_latency_combine(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
handle: tuple,
|
||||||
|
):
|
||||||
|
global _buffer
|
||||||
|
|
||||||
|
# Do MoE combine, compatible with CUDA graph (but you may restore some buffer
|
||||||
|
# status once you replay)
|
||||||
|
combined_hidden_states, event_overlap, hook = _buffer.low_latency_combine(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
handle,
|
||||||
|
async_finish=False,
|
||||||
|
return_recv_hook=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTES: the same behavior as described in the dispatch kernel
|
||||||
|
return combined_hidden_states, event_overlap, hook
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
init_method="env://",
|
||||||
|
)
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
run_high_throughput()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user