mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-18 22:54:34 +08:00
[Bugfix] Revert MoE Triton Config Default (#12629)
SUMMARY: * previous PR for pulling in block configs also changed defaults (https://github.com/vllm-project/vllm/pull/11589/files) for FP8 * this broke L4 MoE since there was not enough SHM for the default configuration * this reverts the non-block example to the default Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
parent
415f19474d
commit
145c2ff648
@ -660,36 +660,17 @@ def get_default_config(
|
||||
is_marlin: bool,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
) -> Dict[str, int]:
|
||||
if dtype == "fp8_w8a8":
|
||||
if block_shape is None:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4,
|
||||
}
|
||||
if M <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4,
|
||||
}
|
||||
else:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
}
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user