[Log] DeepGEMM Update Log for Unaligned Problem Size (#22208)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 6fa41e0c32
commit d7b28f3415
@@ -33,7 +33,7 @@ def deep_gemm_block_shape() -> list[int]:
     return [block, block]


-def _valid_deep_gemm_shape(M: int, N: int, K: int):
+def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
     align = deep_gemm_block_shape()[0]
     return align <= M and N % align == 0 and K % align == 0

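Note: to make the predicate concrete, here is a minimal standalone sketch, assuming a DeepGEMM block size of 128. The real value comes from deep_gemm_block_shape(); the constant and the example shapes below are illustrative only.

# Minimal standalone sketch of the alignment rule above. ASSUMED_BLOCK = 128
# is an assumption; in vLLM the value comes from deep_gemm_block_shape().
ASSUMED_BLOCK = 128


def valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
    # M only needs to reach the block size; N and K must be exact multiples.
    align = ASSUMED_BLOCK
    return align <= M and N % align == 0 and K % align == 0


print(valid_deep_gemm_shape(256, 4096, 7168))  # True: fully aligned
print(valid_deep_gemm_shape(64, 4096, 7168))   # False: M is below the block size
print(valid_deep_gemm_shape(256, 4100, 7168))  # False: N is not a multiple of 128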
@@ -51,9 +51,26 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,

     M = hidden_states.size(0)
     _, K, N = w2.size()

+    align = deep_gemm_block_shape()[0]
+
     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug_once(
-            "DeepGemm disabled: unaligned problem size. M: %s, N: %s, K: %s",
+            "DeepGemm disabled due to unaligned problem size. "
+            "M: %s, N: %s, K: %s. M should be >= align size "
+            "and N and K must be multiples of %s. "
+            "This is not an error and we will fall back to triton.",
             M,
             N,
             K,
+            align,
         )
         return False
+    elif N <= 512:
+        logger.debug_once(
+            "DeepGemm disabled for N <= 512. M: %s, N: %s, K: %s. "
+            "This means we will fall back to triton "
+            "for this specific shape for further speed up.",
+            M,
+            N,
+            K,
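The hunk is cut off by the diff context, so the tail of the new elif branch is not visible. The sketch below is a simplified standalone rendering of the decision flow it implies, using the standard logging module instead of vLLM's logger.debug_once; the 128 block size and the assumption that the N <= 512 branch also returns False are illustrative assumptions, not confirmed by the hunk.

# Simplified sketch of the fallback decision implied by the hunk above.
# ALIGN = 128 and the early return in the N <= 512 branch are assumptions.
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("deep_gemm_sketch")

ALIGN = 128  # assumed DeepGEMM block size


def valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
    return ALIGN <= M and N % ALIGN == 0 and K % ALIGN == 0


def valid_deep_gemm(M: int, N: int, K: int) -> bool:
    if not valid_deep_gemm_shape(M, N, K):
        logger.debug(
            "DeepGemm disabled due to unaligned problem size. "
            "M: %s, N: %s, K: %s. This is not an error; falling back to triton.",
            M, N, K)
        return False
    if N <= 512:
        logger.debug(
            "DeepGemm disabled for N <= 512 (M: %s, N: %s, K: %s); "
            "falling back to triton for this shape for further speed up.",
            M, N, K)
        return False
    return True


print(valid_deep_gemm(1024, 4096, 7168))  # True
print(valid_deep_gemm(1024, 100, 7168))   # False: N not aligned (debug log emitted)
print(valid_deep_gemm(1024, 256, 7168))   # False: N <= 512 (debug log emitted)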
@@ -1360,10 +1360,8 @@ def fused_experts(
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Falling back to cutlass or triton in some cases would cause
     # accuracy issues.
-    N = w1.size(1)
-    should_use_deep_gemm = ((N > 512
-                             and _valid_deep_gemm(hidden_states, w1, w2))
-                            or is_blackwell_deep_gemm_used())
+    should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm(
+        hidden_states, w1, w2)
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (
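For comparison, a sketch of the old versus new gate at this call site, with placeholder predicates standing in for the real vLLM helpers. The point of the refactor is that the N > 512 heuristic no longer lives at the call site but inside the validity check shown in the previous hunk.

# Old vs. new gating logic, with placeholder predicates. The function bodies
# below are stand-ins; only the boolean structure mirrors the diff.
def is_blackwell_deep_gemm_used() -> bool:
    return False  # placeholder: the real check inspects the GPU / DeepGEMM build


def valid_deep_gemm(M: int, N: int, K: int) -> bool:
    # Placeholder: assumed to fold in both the alignment check and N > 512.
    return N > 512 and M >= 128 and N % 128 == 0 and K % 128 == 0


def should_use_deep_gemm_old(M: int, N: int, K: int) -> bool:
    # Before: the N > 512 heuristic was applied at the call site.
    return (N > 512 and valid_deep_gemm(M, N, K)) or is_blackwell_deep_gemm_used()


def should_use_deep_gemm_new(M: int, N: int, K: int) -> bool:
    # After: Blackwell short-circuits first; otherwise the validity helper
    # (which now owns the N > 512 heuristic and logs the fallback) decides.
    return is_blackwell_deep_gemm_used() or valid_deep_gemm(M, N, K)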
@@ -107,8 +107,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K)
-                                     or is_blackwell_deep_gemm_used()):
+        if self.allow_deep_gemm and (is_blackwell_deep_gemm_used()
+                                     or _valid_deep_gemm_shape(M, N, K)):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
                 a, aq, M, N, K, topk, global_num_experts, local_num_experts,
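The comment describes a pessimistic-allocation strategy: whenever DeepGemm might run, size the workspace for DeepGemm, since it is stated to be strictly larger than the triton workspace, so a later fallback to triton still fits. A minimal sketch of that pattern follows; the byte formulas are invented for illustration and only the selection logic reflects the comment.

# Pessimistic workspace sizing, sketched with invented formulas.
def triton_workspace_bytes(M: int, N: int) -> int:
    return M * N * 2              # placeholder estimate


def deep_gemm_workspace_bytes(M: int, N: int, K: int) -> int:
    return M * N * 2 + M * K * 2  # placeholder, assumed strictly larger


def pick_workspace_bytes(M: int, N: int, K: int,
                         allow_deep_gemm: bool,
                         deep_gemm_shape_ok: bool) -> int:
    if allow_deep_gemm and deep_gemm_shape_ok:
        # Allocate for DeepGemm even if we later fall back to triton
        # (e.g. when expert maps are set): the larger buffer still fits.
        return deep_gemm_workspace_bytes(M, N, K)
    return triton_workspace_bytes(M, N)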