# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import glob import itertools import os import subprocess import sys import jinja2 ARCHS = [] SUPPORT_FP8 = False for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) # only SM89 and SM120 fully support # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. if arch in [89, 120]: SUPPORT_FP8 = True FILE_HEAD_COMMENT = """ // auto generated by generate_kernels.py // clang-format off """.lstrip() FILE_HEAD = ( FILE_HEAD_COMMENT + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { """ ) TEMPLATE = ( "template __global__ void Marlin<" "{{a_type_id}}, " "{{b_type_id}}, " "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] QUANT_CONFIGS = [ # AWQ-INT4 { "b_type": "kU4", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [-1, 2, 4, 8], }, # HQQ { "a_type": ["kFloat16"], "b_type": "kU4", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [4], "is_zp_float": True, }, # GPTQ-INT4 { "b_type": "kU4B8", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [-1, 0, 2, 4, 8], }, # GPTQ-INT8 { "b_type": "kU8B128", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [-1, 0, 2, 4, 8], }, # FP8 { "b_type": "kFE4M3fn", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [-1, 8], }, # NVFP4 { "b_type": "kFE2M1f", "s_type": "kFE4M3fn", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [1], }, # MXFP4 { "a_type": ["kBFloat16"], "b_type": "kFE2M1f", "s_type": "kFE8M0fnu", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": THREAD_M_BLOCKS, "group_blocks": [2], }, # AWQ-INT4 with INT8 activation { "a_type": ["kS8"], "b_type": "kU4", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": [1, 2, 3, 4], "group_blocks": [-1, 2, 4, 8], }, # GPTQ-INT4 with INT8 activation { "a_type": ["kS8"], "b_type": "kU4B8", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": [1, 2, 3, 4], "group_blocks": [-1, 2, 4, 8], }, # GPTQ-INT4 with FP8 activation { "a_type": ["kFE4M3fn"], "b_type": "kU4B8", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": [1, 2, 3, 4], "group_blocks": [-1, 2, 4, 8], }, # AWQ-INT4 with FP8 activation { "a_type": ["kFE4M3fn"], "b_type": "kU4", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": [1, 2, 3, 4], "group_blocks": [-1, 2, 4, 8], }, # MXFP4 with FP8 activation { "a_type": ["kFE4M3fn"], "b_type": "kFE2M1f", "c_type": ["kBFloat16"], "s_type": "kFE8M0fnu", "thread_configs": THREAD_CONFIGS, "thread_m_blocks": [1, 2, 3, 4], "group_blocks": [2], }, ] def remove_old_kernels(): for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) filename = os.path.dirname(__file__) + "/kernel_selector.h" subprocess.call(["rm", "-f", filename]) def generate_new_kernels(): result_dict = {} for quant_config in QUANT_CONFIGS: c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) b_type = quant_config["b_type"] is_zp_float = quant_config.get("is_zp_float", False) all_group_blocks = quant_config["group_blocks"] all_m_blocks = quant_config["thread_m_blocks"] all_thread_configs = quant_config["thread_configs"] for a_type, c_type in itertools.product(a_types, c_types): if not SUPPORT_FP8 and a_type == "kFE4M3fn": continue if "16" in a_type and "16" in c_type and a_type != c_type: continue s_type = quant_config.get("s_type", c_type) if (a_type, b_type, c_type) not in result_dict: result_dict[(a_type, b_type, c_type)] = [] for group_blocks, m_blocks, thread_configs in itertools.product( all_group_blocks, all_m_blocks, all_thread_configs ): thread_k, thread_n, threads = thread_configs if threads == 256: # for small batch (m_blocks == 1), # we only need (128, 128, 256) # for large batch (m_blocks > 1), # we only need (64, 256, 256) if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): continue if m_blocks > 1 and (thread_k, thread_n) != (64, 256): continue config = { "threads": threads, "s_type": s_type, "thread_m_blocks": max(m_blocks, 1), "thread_k_blocks": thread_k // 16, "thread_n_blocks": thread_n // 16, "m_block_size_8": "true" if m_blocks == 0.5 else "false", "stages": "pipe_stages", "group_blocks": group_blocks, "is_zp_float": "true" if is_zp_float else "false", } result_dict[(a_type, b_type, c_type)].append(config) kernel_selector_str = FILE_HEAD_COMMENT for (a_type, b_type, c_type), config_list in result_dict.items(): all_template_str_list = [] for config in config_list: s_type = config["s_type"] template_str = jinja2.Template(TEMPLATE).render( a_type_id=f"vllm::{a_type}.id()", b_type_id=f"vllm::{b_type}.id()", c_type_id=f"vllm::{c_type}.id()", s_type_id=f"vllm::{s_type}.id()", **config, ) all_template_str_list.append(template_str) conditions = [ f"a_type == vllm::{a_type}", f"b_type == vllm::{b_type}", f"c_type == vllm::{c_type}", f"s_type == vllm::{s_type}", f"threads == {config['threads']}", f"thread_m_blocks == {config['thread_m_blocks']}", f"thread_n_blocks == {config['thread_n_blocks']}", f"thread_k_blocks == {config['thread_k_blocks']}", f"m_block_size_8 == {config['m_block_size_8']}", f"group_blocks == {config['group_blocks']}", f"is_zp_float == {config['is_zp_float']}", ] conditions = " && ".join(conditions) if kernel_selector_str == FILE_HEAD_COMMENT: kernel_selector_str += f"if ({conditions})\n kernel = " else: kernel_selector_str += f"else if ({conditions})\n kernel = " kernel_template2 = ( "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " "{{thread_n_blocks}}, {{thread_k_blocks}}, " "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " "{{is_zp_float}}>;" ) kernel_selector_str += ( jinja2.Template(kernel_template2).render( a_type_id=f"vllm::{a_type}.id()", b_type_id=f"vllm::{b_type}.id()", c_type_id=f"vllm::{c_type}.id()", s_type_id=f"vllm::{s_type}.id()", **config, ) + "\n" ) file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" if a_type == "kFE4M3fn": filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" else: filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: kernel_selector_str += ( "else if (a_type == vllm::kFE4M3fn)\n" " TORCH_CHECK(false, " '"marlin kernel with fp8 activation is not built.");' ) with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: f.write(kernel_selector_str) if __name__ == "__main__": remove_old_kernels() generate_new_kernels()