[Model] Enable FP8 QKV in MoE and refine kernel tuning script (#5039)

This commit is contained in:
Cody Yu 2024-05-31 14:29:19 -07:00 committed by GitHub
parent a377f0bd5e
commit e9899fb7a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 714 additions and 117 deletions

View File

@ -11,25 +11,36 @@ from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import (fused_moe, from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name) get_config_file_name)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def main(model, tp_size, gpu, dtype: str):
def main(dtype: str): os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
method = fused_moe method = fused_moe
for bs in [ for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096 2048, 3072, 4096
]: ]:
run_grid(bs, method=method, dtype=dtype) run_grid(bs,
model=model,
method=method,
gpu=gpu,
tp_size=tp_size,
dtype=dtype)
def run_grid(bs, method, dtype: str): def run_grid(bs, model, method, gpu, tp_size, dtype: str):
d_model = 4096 if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
num_layers = 32
elif model == '8x22B':
d_model = 6144
model_intermediate_size = 16384
num_layers = 56
else:
raise ValueError(f'Unsupported Mixtral model {model}')
num_total_experts = 8 num_total_experts = 8
top_k = 2 top_k = 2
tp_size = 2 # tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100 num_calls = 100
num_warmup_trials = 1 num_warmup_trials = 1
@ -211,5 +222,18 @@ if __name__ == "__main__":
choices=['float8', 'float16'], choices=['float8', 'float16'],
help='Data type used for fused_moe kernel computations', help='Data type used for fused_moe kernel computations',
) )
parser.add_argument('--model',
type=str,
default='8x7B',
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')
parser.add_argument('--tp-size',
type=int,
default=2,
                        help='Tensor parallel size')
parser.add_argument('--gpu',
type=int,
default=0,
help="GPU ID for benchmarking")
args = parser.parse_args() args = parser.parse_args()
sys.exit(main(args.dtype)) sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))

View File

@ -0,0 +1,138 @@
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
}
}

View File

@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -3,61 +3,59 @@
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1 "GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}, },
"2": { "2": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1 "GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
}, },
"4": { "4": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1 "GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}, },
"8": { "8": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 32,
"num_warps": 8, "num_warps": 8,
"num_stages": 5 "num_stages": 2
}, },
"16": { "16": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 8,
"num_stages": 5 "num_stages": 5
}, },
"24": { "24": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 64,
"num_warps": 8, "num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4 "num_stages": 4
}, },
"48": { "48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
@ -65,37 +63,45 @@
"num_warps": 4, "num_warps": 4,
"num_stages": 4 "num_stages": 4
}, },
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": { "96": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 32,
"num_warps": 8, "num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5 "num_stages": 5
}, },
"512": { "256": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 8,
"num_stages": 2 "num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
}, },
"1024": { "1024": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
@ -109,7 +115,7 @@
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 64,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 4
}, },
@ -125,7 +131,7 @@
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 16,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 4
}, },

View File

@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -2,104 +2,104 @@
"1": { "1": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 4 "num_stages": 5
}, },
"2": { "2": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 32,
"num_warps": 4, "num_warps": 8,
"num_stages": 4 "num_stages": 4
}, },
"4": { "4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 4 "num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}, },
"16": { "16": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 4 "num_stages": 4
}, },
"24": { "24": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 4 "num_stages": 5
}, },
"32": { "32": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 64,
"num_warps": 8, "num_warps": 4,
"num_stages": 4 "num_stages": 4
}, },
"48": { "48": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 4 "num_stages": 3
}, },
"64": { "64": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 4 "num_stages": 4
}, },
"96": { "96": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
"num_stages": 4 "num_stages": 4
}, },
"128": { "128": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 4 "num_stages": 3
}, },
"256": { "256": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 5
}, },
"512": { "512": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16, "GROUP_SIZE_M": 64,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 4
}, },
@ -115,7 +115,7 @@
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 64,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 4
}, },
@ -139,7 +139,7 @@
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128, "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64, "GROUP_SIZE_M": 16,
"num_warps": 8, "num_warps": 8,
"num_stages": 4 "num_stages": 4
} }

View File

@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -278,15 +278,6 @@ class MixtralAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
if isinstance(
quant_config,
Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
quant_config = None
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
self.head_dim, self.head_dim,