diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 0929530ebec4..70d0037d7cb0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -189,11 +189,7 @@ def fused_moe_kernel_gptq_awq(
                     mask=token_mask[:, None] &
                     (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                     other=0.0)
-        b = tl.load(
-            b_ptrs,
-            cache_modifier=".cg",
-            eviction_policy="evict_last",
-        )
+        b = tl.load(b_ptrs)
 
         if use_int4_w4a16:
             b = (b >> b_shifter) & 0xF
@@ -395,13 +391,9 @@ def fused_moe_kernel(
                     mask=token_mask[:, None] &
                     (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                     other=0.0)
-        b = tl.load(
-            b_ptrs,
-            mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-            other=0.0,
-            cache_modifier=".cg",
-            eviction_policy="evict_last",
-        )
+        b = tl.load(b_ptrs,
+                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
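
Note on the change above: Triton's tl.load accepts optional cache_modifier and
eviction_policy hints, which lower to PTX cache operators; ".cg" caches loads
at the global (L2) level, bypassing L1, and "evict_last" asks the cache to
deprioritize evicting those lines. This patch drops both hints from the
expert-weight loads, falling back to the compiler's default caching behaviour.
Below is a minimal, self-contained sketch (not part of this patch) contrasting
the two load forms; copy_kernel, USE_HINTS, and the tensor shapes are
illustrative only, not names from the vLLM kernels.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(x_ptr, y_ptr, n_elements, USE_HINTS: tl.constexpr,
                    BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < n_elements
        if USE_HINTS:
            # The form the patch removes: cache at L2 only, keep lines
            # resident as long as possible.
            x = tl.load(x_ptr + offs, mask=mask, other=0.0,
                        cache_modifier=".cg", eviction_policy="evict_last")
        else:
            # The form the patch keeps: default caching behaviour.
            x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        tl.store(y_ptr + offs, x, mask=mask)


    # Hypothetical usage, requires a CUDA device:
    x = torch.randn(1 << 20, device="cuda")
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    copy_kernel[grid](x, y, x.numel(), USE_HINTS=False, BLOCK_SIZE=1024)

Because USE_HINTS is a tl.constexpr, each setting compiles to a separate
specialization, so the hinted and unhinted loads can be benchmarked against
each other without any runtime branching inside the kernel.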