mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 02:06:31 +08:00
[Kernel] Optimize moe intermediate_cache usage (#13625)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
2b04c209ee
commit
19d98e0c7d
@@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
|
||||
config = get_config_func(M)
|
||||
|
||||
intermediate_cache1 = torch.empty((M, top_k_num, N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache1 = cache13[:M * top_k_num * N].view(
|
||||
(M, topk_ids.shape[1], N))
|
||||
intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
|
||||
(M, topk_ids.shape[1], w2.shape[1]))
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user