[Kernel] Optimize moe intermediate_cache usage (#13625)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-03-03 16:29:53 -05:00 committed by GitHub
parent 2b04c209ee
commit 19d98e0c7d

@@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, top_k_num, N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(
+        (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
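
To make the trick concrete, here is a minimal standalone sketch of the same pattern, not taken from the PR itself: the sizes M, top_k, N and K are made up for illustration (K stands in for w2.shape[1]). One flat allocation backs both intermediate caches, which is safe here because cache1's contents are fully consumed (via the activation into cache2) before the second GEMM writes cache3.

import torch

# Illustrative sizes only; in the kernel these come from the token batch,
# the top-k routing width, and the expert weight shapes.
M, top_k, N, K = 8, 2, 1024, 512

# One flat buffer sized for the larger of the two logical tensors.
cache13 = torch.empty(M * top_k * max(N, K), dtype=torch.bfloat16)

# Both caches are views into the same storage, so no second allocation.
cache1 = cache13[:M * top_k * N].view(M, top_k, N)
cache3 = cache13[:M * top_k * K].view(M, top_k, K)
assert cache1.data_ptr() == cache3.data_ptr()  # shared underlying memory

# Savings versus two separate buffers: max(N, K) elements per routed token
# instead of N + K.
elem = cache13.element_size()
print("separate:", M * top_k * (N + K) * elem, "bytes")
print("shared:  ", M * top_k * max(N, K) * elem, "bytes")

Note that intermediate_cache2 cannot join the shared buffer: it is written from cache1 while cache1 is still live, which is exactly what the added comment in the diff points out.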