mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 05:05:01 +08:00
[feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
This commit is contained in:
parent
34dad19e7b
commit
97abeb1daa
@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
|||||||
return out.to(dtype=out_dtype)
|
return out.to(dtype=out_dtype)
|
||||||
|
|
||||||
|
|
||||||
def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
|
def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor) -> bool:
|
w2: torch.Tensor) -> bool:
|
||||||
|
|
||||||
def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
|
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
|
||||||
return M >= 128 and N % 128 == 0 and K % 128 == 0
|
return N % 128 == 0 and K % 128 == 0
|
||||||
|
|
||||||
m = hidden_states.size(0)
|
|
||||||
_, K, N = w2.size()
|
_, K, N = w2.size()
|
||||||
if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
|
if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
|
"CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -1180,7 +1180,7 @@ def fused_experts(
|
|||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
|
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
|
||||||
and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
|
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
|
||||||
assert apply_router_weight_on_input is False
|
assert apply_router_weight_on_input is False
|
||||||
return run_cutlass_block_scaled_fused_experts(
|
return run_cutlass_block_scaled_fused_experts(
|
||||||
a=hidden_states,
|
a=hidden_states,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user