From 7a888271f5bd401f8fc64704c239833244471a91 Mon Sep 17 00:00:00 2001
From: Wes
Date: Wed, 26 Mar 2025 17:21:34 -0600
Subject: [PATCH] Use Cache Hinting for fused_moe kernel (#15511)

---
 .../model_executor/layers/fused_moe/fused_moe.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 97e915c60335a..faaea6b4de972 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -189,7 +189,11 @@ def fused_moe_kernel_gptq_awq(
                     mask=token_mask[:, None] &
                     (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                     other=0.0)
-        b = tl.load(b_ptrs)
+        b = tl.load(
+            b_ptrs,
+            cache_modifier=".cg",
+            eviction_policy="evict_last",
+        )
         if use_int4_w4a16:
             b = (b >> b_shifter) & 0xF
 
@@ -391,9 +395,13 @@ def fused_moe_kernel(
                     mask=token_mask[:, None] &
                     (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                     other=0.0)
-        b = tl.load(b_ptrs,
-                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-                    other=0.0)
+        b = tl.load(
+            b_ptrs,
+            mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+            other=0.0,
+            cache_modifier=".cg",
+            eviction_policy="evict_last",
+        )
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)