mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:25:34 +08:00
[feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
This commit is contained in:
parent
34dad19e7b
commit
97abeb1daa
@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
return out.to(dtype=out_dtype)
|
||||
|
||||
|
||||
def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
|
||||
w2: torch.Tensor) -> bool:
|
||||
|
||||
def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
|
||||
return M >= 128 and N % 128 == 0 and K % 128 == 0
|
||||
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
|
||||
return N % 128 == 0 and K % 128 == 0
|
||||
|
||||
m = hidden_states.size(0)
|
||||
_, K, N = w2.size()
|
||||
if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
|
||||
if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
|
||||
logger.debug(
|
||||
"CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
|
||||
return False
|
||||
|
||||
@ -1180,7 +1180,7 @@ def fused_experts(
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
|
||||
and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
|
||||
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
|
||||
assert apply_router_weight_on_input is False
|
||||
return run_cutlass_block_scaled_fused_experts(
|
||||
a=hidden_states,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user