[Misc] support model prefix & add deepseek vl2 tiny fused moe config (#17763)

Signed-off-by: 唯勤 <xsank.mz@alibaba-inc.com>
Co-authored-by: 唯勤 <xsank.mz@alibaba-inc.com>
This commit is contained in:
xsank 2025-05-08 15:50:22 +08:00 committed by GitHub
parent 39956efb3f
commit 0a9bbaa104
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 161 additions and 9 deletions

View File

@ -6,15 +6,16 @@ import time
from contextlib import nullcontext from contextlib import nullcontext
from datetime import datetime from datetime import datetime
from itertools import product from itertools import product
from types import SimpleNamespace
from typing import Any, TypedDict from typing import Any, TypedDict
import ray import ray
import torch import torch
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -534,8 +535,12 @@ def get_weight_block_size_safety(config, default_value=None):
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
config = AutoConfig.from_pretrained( config = get_config(model=args.model,
args.model, trust_remote_code=args.trust_remote_code) trust_remote_code=args.trust_remote_code)
if args.model_prefix:
config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
@ -546,15 +551,14 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif (config.architectures[0] == "DeepseekV3ForCausalLM" elif (config.architectures[0]
or config.architectures[0] == "DeepseekV2ForCausalLM"): in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [ elif config.architectures[0] in ("Qwen2MoeForCausalLM",
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" "Qwen3MoeForCausalLM"):
]:
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
@ -569,7 +573,8 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = torch.float16 if current_platform.is_rocm() else getattr(
torch, config.torch_dtype)
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
@ -659,6 +664,7 @@ if __name__ == "__main__":
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
}
}