From 19d98e0c7db96713f0e2201649159431177a56e2 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 3 Mar 2025 16:29:53 -0500
Subject: [PATCH] [Kernel] Optimize moe intermediate_cache usage (#13625)

Signed-off-by: mgoin
---
 .../layers/fused_moe/fused_moe.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 00260313e72eb..5336b3c100235 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, top_k_num, N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(
+        (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
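
Note (not part of the patch): the diff above shares one flat allocation between
intermediate_cache1 and intermediate_cache3, which is safe because cache1 has
already been fully consumed (into cache2) by the time cache3 is written, while
cache2 still needs its own buffer since it is written while cache1 is read.
Below is a minimal, standalone Python sketch of that buffer-reuse pattern; the
names M, top_k, N, K and the shapes are illustrative placeholders, with K
standing in for w2.shape[1].

    # Standalone illustration of the buffer-reuse pattern from the patch above.
    # One flat allocation backs two intermediate tensors whose lifetimes do not
    # overlap: cache1 is fully consumed before cache3 is ever written.
    import torch

    M, top_k, N, K = 8, 2, 256, 128   # illustrative sizes; K ~ w2.shape[1]

    # Single allocation sized for the larger of the two views.
    cache13 = torch.empty(M * top_k * max(N, K), dtype=torch.float32)

    cache1 = cache13[:M * top_k * N].view(M, top_k, N)   # used first
    cache3 = cache13[:M * top_k * K].view(M, top_k, K)   # reused later

    # Both views alias the start of the same storage, so no extra memory is used.
    assert cache1.data_ptr() == cache3.data_ptr() == cache13.data_ptr()

    # Safe usage order: finish reading cache1 before writing cache3.
    cache1.fill_(1.0)              # stage 1 writes cache1
    cache2 = cache1.sum(dim=-1)    # stage 2 consumes cache1 into cache2 ...
    cache3.fill_(2.0)              # ... only then is cache3 written, clobbering cache1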