Use Cache Hinting for fused_moe kernel (#15511)

This commit is contained in:
Author: Wes · Date: 2025-03-26 17:21:34 -06:00 · Committed by: GitHub
parent 9d119a86ae
commit 7a888271f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -189,7 +189,11 @@ def fused_moe_kernel_gptq_awq(
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs)
b = tl.load(
b_ptrs,
cache_modifier=".cg",
eviction_policy="evict_last",
)
if use_int4_w4a16:
b = (b >> b_shifter) & 0xF
@@ -391,9 +395,13 @@ def fused_moe_kernel(
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(
b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0,
cache_modifier=".cg",
eviction_policy="evict_last",
)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)