diff --git a/.buildkite/performance-benchmarks/README.md b/.buildkite/performance-benchmarks/README.md index 6d494f64f14fa..015f48c2520d6 100644 --- a/.buildkite/performance-benchmarks/README.md +++ b/.buildkite/performance-benchmarks/README.md @@ -108,6 +108,65 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. +#### Default Parameters Field + +We can specify default parameters in a JSON field with key `defaults`. Parameters defined in the field are applied globally to all serving tests, and can be overridden in test case fields. Here is an example: + +
+<summary> An Example of default parameters field </summary> + +```json +{ + "defaults": { + "qps_list": [ + "inf" + ], + "server_environment_variables": { + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1 + }, + "server_parameters": { + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "block_size": 128, + "disable_log_stats": "", + "load_format": "dummy" + }, + "client_parameters": { + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "num_prompts": 200, + "ignore-eos": "" + } + }, + "tests": [ + { + "test_name": "serving_llama3B_tp2_random_128_128", + "server_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "tensor_parallel_size": 2 + }, + "client_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct" + } + }, + { + "test_name": "serving_qwen3_tp4_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-14B", + "tensor_parallel_size": 4 + }, + "client_parameters": { + "model": "Qwen/Qwen3-14B" + } + } + ] +} +``` + +</
+ ### Visualizing the results The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](performance-benchmarks-descriptions.md) with real benchmarking results. diff --git a/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh index 99a5a5e334f8e..34ceefe0996f2 100644 --- a/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh @@ -110,7 +110,8 @@ json2envs() { wait_for_server() { # wait for vllm server to start # return 1 if vllm server crashes - timeout 1200 bash -c ' + local timeout_val="1200" + timeout "$timeout_val" bash -c ' until curl -X POST localhost:8000/v1/completions; do sleep 1 done' && return 0 || return 1 @@ -316,12 +317,44 @@ run_throughput_tests() { run_serving_tests() { # run serving tests using `vllm bench serve` command # $1: a json file specifying serving test cases + # + # Supported JSON formats: + # 1) Plain format: top-level array + # [ { "test_name": "...", "server_parameters": {...}, ... }, ... ] + # + # 2) Default parameters field + plain format tests + # { + # "defaults": { ... }, + # "tests": [ { "test_name": "...", "server_parameters": {...}, ... }, ... ] + # } local serving_test_file serving_test_file=$1 # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do + jq -c ' + if type == "array" then + # Plain format: test cases array + .[] + elif (type == "object" and has("tests")) then + # merge the default parameters into each test cases + . 
as $root + | ($root.defaults // {}) as $d + | ($root.tests // [])[] + # default qps / max_concurrency from defaults if missing + | .qps_list = (.qps_list // $d.qps_list) + | .max_concurrency_list = (.max_concurrency_list // $d.max_concurrency_list) + # merge envs / params: test overrides defaults + | .server_environment_variables = + (($d.server_environment_variables // {}) + (.server_environment_variables // {})) + | .server_parameters = + (($d.server_parameters // {}) + (.server_parameters // {})) + | .client_parameters = + (($d.client_parameters // {}) + (.client_parameters // {})) + else + error("Unsupported serving test file format: must be array or object with .tests") + end + ' "$serving_test_file" | while read -r params; do # get the test name, and append the GPU type back to it. test_name=$(echo "$params" | jq -r '.test_name') if [[ ! "$test_name" =~ ^serving_ ]]; then @@ -335,20 +368,25 @@ run_serving_tests() { continue fi - # get client and server arguments + # get client and server arguments (after merged the default parameters) server_params=$(echo "$params" | jq -r '.server_parameters') server_envs=$(echo "$params" | jq -r '.server_environment_variables') client_params=$(echo "$params" | jq -r '.client_parameters') + server_args=$(json2args "$server_params") server_envs=$(json2envs "$server_envs") client_args=$(json2args "$client_params") + + # qps_list qps_list=$(echo "$params" | jq -r '.qps_list') qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') echo "Running over qps list $qps_list" + + # max_concurrency_list (fallback to num_prompts if missing) max_concurrency_list=$(echo "$params" | jq -r '.max_concurrency_list') if [[ -z "$max_concurrency_list" || "$max_concurrency_list" == "null" ]]; then - num_prompts=$(echo "$client_params" | jq -r '.num_prompts') - max_concurrency_list="[$num_prompts]" + num_prompts=$(echo "$client_params" | jq -r '.num_prompts') + max_concurrency_list="[$num_prompts]" fi max_concurrency_list=$(echo "$max_concurrency_list" | 
jq -r '.[] | @sh') echo "Running over max concurrency list $max_concurrency_list" diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json deleted file mode 100644 index f758097e098e4..0000000000000 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json +++ /dev/null @@ -1,610 +0,0 @@ -[ - { - "test_name": "serving_llama8B_bf16_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": 
"dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": 
"random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, 
- "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_sharegpt", - 
"qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_random_128_128", - "qps_list": ["inf"], - 
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp1_sharegpt", - "qps_list": ["inf"], - 
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 
200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 
64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - } -] diff --git 
a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json deleted file mode 100644 index 0b1a42e790255..0000000000000 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json +++ /dev/null @@ -1,1023 +0,0 @@ -[ - { - "test_name": "serving_llama8B_bf16_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - 
"backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": 
"serving_llama8B_bf16_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_random_128_128", - "qps_list": ["inf"], - 
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - 
"server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - 
"VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - 
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - 
"VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - 
"VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - 
"VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - 
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - 
"VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - 
"VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - 
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - 
"VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - } -] diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json 
b/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json index f792956f39472..8f7200862d20c 100644 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json +++ b/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json @@ -1,276 +1,246 @@ -[ - { - "test_name": "serving_llama8B_tp1_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 32 - } +{ + "defaults": { + "qps_list": [ + "inf" + ], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 }, - { - "test_name": "serving_llama8B_tp2_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - 
"trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 32 - } + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" }, - { - "test_name": "serving_llama8B_tp1_random_128_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp2_random_128_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - 
"server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp1_random_128_2048", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 2048, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp2_random_128_2048", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": 
"bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 2048, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp1_random_2048_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 2048, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp2_random_2048_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - 
"enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 2048, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "ignore-eos": "", + "num_prompts": 200 } -] + }, + "tests": [ + { + "test_name": "serving_llama8B_tp1_sharegpt", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json" + } + }, + { + "test_name": "serving_llama8B_tp2_sharegpt", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json" + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_128", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp2_random_128_128", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp4_random_128_128", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": 
"serving_llama8B_tp2_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp4_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp1_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp2_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp4_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama3B_tp1_random_128_128", + "server_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_granite2B_tp1_random_128_128", + "server_parameters": { + "model": "ibm-granite/granite-3.2-2b-instruct", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "ibm-granite/granite-3.2-2b-instruct", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_qwen1.7B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-1.7B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": 
"Qwen/Qwen3-1.7B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_qwen4B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-4B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-4B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_qwen8B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-8B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-8B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_glm9B_tp1_random_128_128", + "server_parameters": { + "model": "zai-org/glm-4-9b-hf", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "zai-org/glm-4-9b-hf", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_gemma7B_tp1_random_128_128", + "server_parameters": { + "model": "google/gemma-7b", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "google/gemma-7b", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + } + ] +} diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 38c400ba1faf5..fbfc923998f89 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -8,7 +8,7 @@ steps: commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f 
docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -30,19 +30,6 @@ steps: DOCKER_BUILDKIT: "1" # x86 + CUDA builds - - label: "Build wheel - CUDA 12.8" - depends_on: ~ - id: build-wheel-cuda-12-8 - agents: - queue: cpu_queue_postmerge - commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - - "mkdir artifacts" - - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" - env: - DOCKER_BUILDKIT: "1" - - label: "Build wheel - CUDA 12.9" depends_on: ~ id: build-wheel-cuda-12-9 @@ -109,7 +96,6 @@ steps: - label: "Annotate release workflow" depends_on: - create-multi-arch-manifest - - build-wheel-cuda-12-8 id: annotate-release-workflow agents: queue: cpu_queue_postmerge diff --git a/.buildkite/scripts/generate-nightly-index.py b/.buildkite/scripts/generate-nightly-index.py new file mode 100644 index 0000000000000..8d09ba178db7b --- /dev/null +++ b/.buildkite/scripts/generate-nightly-index.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# do not complain about line length (for docstring) +# ruff: noqa: E501 + +import argparse +import json +import re +import sys +from dataclasses 
import asdict, dataclass +from pathlib import Path +from typing import Any +from urllib.parse import quote + +if not sys.version_info >= (3, 12): + raise RuntimeError("This script requires Python 3.12 or higher.") + +INDEX_HTML_TEMPLATE = """ + + + +{items} + + +""" + + +@dataclass +class WheelFileInfo: + package_name: str + version: str + build_tag: str | None + python_tag: str + abi_tag: str + platform_tag: str + variant: str | None + filename: str + + +def parse_from_filename(file: str) -> WheelFileInfo: + """ + Parse wheel file name to extract metadata. + + The format of wheel names: + {package_name}-{version}(-{build_tag})?-{python_tag}-{abi_tag}-{platform_tag}.whl + All versions could contain a variant like '+cu129' or '.cpu' or `.rocm` (or not). + Example: + vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl + vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl + vllm-0.11.1rc8.dev14+gaa384b3c0-cp38-abi3-manylinux2014_aarch64.whl + vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl + """ + wheel_file_re = re.compile( + r"^(?P.+)-(?P[^-]+?)(-(?P[^-]+))?-(?P[^-]+)-(?P[^-]+)-(?P[^-]+)\.whl$" + ) + match = wheel_file_re.match(file) + if not match: + raise ValueError(f"Invalid wheel file name: {file}") + + package_name = match.group("package_name") + version = match.group("version") + build_tag = match.group("build_tag") + python_tag = match.group("python_tag") + abi_tag = match.group("abi_tag") + platform_tag = match.group("platform_tag") + + # extract variant from version + variant = None + if "dev" in version: + ver_after_dev = version.split("dev")[-1] + if "." in ver_after_dev: + variant = ver_after_dev.split(".")[-1] + version = version.removesuffix("." 
+ variant) + else: + if "+" in version: + version, variant = version.split("+") + + return WheelFileInfo( + package_name=package_name, + version=version, + build_tag=build_tag, + python_tag=python_tag, + abi_tag=abi_tag, + platform_tag=platform_tag, + variant=variant, + filename=file, + ) + + +def generate_project_list(subdir_names: list[str]) -> str: + """ + Generate project list HTML content linking to each project & variant sub-directory. + """ + href_tags = [] + for name in sorted(subdir_names): + name = name.strip("/").strip(".") + href_tags.append(f' {name}/
') + return INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags)) + + +def generate_package_index_and_metadata( + wheel_files: list[WheelFileInfo], wheel_base_dir: Path, index_base_dir: Path +) -> tuple[str, str]: + """ + Generate package index HTML content for a specific package, linking to actual wheel files. + """ + href_tags = [] + metadata = [] + for file in sorted(wheel_files, key=lambda x: x.filename): + relative_path = ( + wheel_base_dir.relative_to(index_base_dir, walk_up=True) / file.filename + ) + # handle with '+' in URL, and avoid double-encoding '/' and already-encoded '%2B' + # NOTE: this is AWS S3 specific behavior! + file_path_quoted = quote(relative_path.as_posix(), safe=":%/") + href_tags.append(f' {file.filename}
') + file_meta = asdict(file) + file_meta["path"] = file_path_quoted + metadata.append(file_meta) + index_str = INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags)) + metadata_str = json.dumps(metadata, indent=2) + return index_str, metadata_str + + +def generate_index_and_metadata( + whl_files: list[str], + wheel_base_dir: Path, + index_base_dir: Path, + default_variant: str | None = None, + alias_to_default: str | None = None, +): + """ + Generate index for all wheel files. + + Args: + whl_files (list[str]): List of wheel files (must be directly under `wheel_base_dir`). + wheel_base_dir (Path): Base directory for wheel files. + index_base_dir (Path): Base directory to store index files. + default_variant (str | None): The default variant name, if any. + alias_to_default (str | None): Alias variant name for the default variant, if any. + + First, parse all wheel files to extract metadata. + We need to collect all wheel files for each variant, and generate an index for it (in a sub-directory). + The index for the default variant (if any) is generated in the root index directory. + + If `default_variant` is provided, all wheels must have variant suffixes, and the default variant index + is purely a copy of the corresponding variant index, with only the links adjusted. + Otherwise, all wheels without variant suffixes are treated as the default variant. + + If `alias_to_default` is provided, an additional alias sub-directory is created, it has the same content + as the default variant index, but the links are adjusted accordingly. 
+ + Index directory structure: + index_base_dir/ (hosted at wheels.vllm.ai/{nightly,$commit,$version}/) + index.html # project list, linking to "vllm/" and other packages, and all variant sub-directories + vllm/ + index.html # package index, pointing to actual files in wheel_base_dir (relative path) + metadata.json # machine-readable metadata for all wheels in this package + cpu/ # cpu variant sub-directory + index.html + vllm/ + index.html + metadata.json + cu129/ # cu129 is actually the alias to default variant + index.html + vllm/ + index.html + metadata.json + cu130/ # cu130 variant sub-directory + index.html + vllm/ + index.html + metadata.json + ... + + metadata.json stores a dump of all wheel files' metadata in a machine-readable format: + [ + { + "package_name": "vllm", + "version": "0.10.2rc2", + "build_tag": null, + "python_tag": "cp38", + "abi_tag": "abi3", + "platform_tag": "manylinux2014_aarch64", + "variant": "cu129", + "filename": "vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl", + "path": "../vllm-0.10.2rc2%2Bcu129-cp38-abi3-manylinux2014_aarch64.whl" # to be concatenated with the directory URL and URL-encoded + }, + ... + ] + """ + + parsed_files = [parse_from_filename(f) for f in whl_files] + + if not parsed_files: + print("No wheel files found, skipping index generation.") + return + + # Group by variant + variant_to_files: dict[str, list[WheelFileInfo]] = {} + for file in parsed_files: + variant = file.variant or "default" + if variant not in variant_to_files: + variant_to_files[variant] = [] + variant_to_files[variant].append(file) + + print(f"Found variants: {list(variant_to_files.keys())}") + + # sanity check for default variant + if default_variant: + if "default" in variant_to_files: + raise ValueError( + "All wheel files must have variant suffixes when `default_variant` is specified." + ) + if default_variant not in variant_to_files: + raise ValueError( + f"Default variant '{default_variant}' not found among wheel files." 
+ ) + + if alias_to_default: + if "default" not in variant_to_files: + # e.g. only some wheels are uploaded to S3 currently + print( + "[WARN] Alias to default variant specified, but no default variant found." + ) + elif alias_to_default in variant_to_files: + raise ValueError( + f"Alias variant name '{alias_to_default}' already exists among wheel files." + ) + else: + variant_to_files[alias_to_default] = variant_to_files["default"].copy() + print(f"Alias variant '{alias_to_default}' created for default variant.") + + # Generate index for each variant + subdir_names = set() + for variant, files in variant_to_files.items(): + if variant == "default": + variant_dir = index_base_dir + else: + variant_dir = index_base_dir / variant + subdir_names.add(variant) + + variant_dir.mkdir(parents=True, exist_ok=True) + + # gather all package names in this variant + packages = set(f.package_name for f in files) + if variant == "default": + # these packages should also appear in the "project list" + # generate after all variants are processed + subdir_names = subdir_names.union(packages) + else: + # generate project list for this variant directly + project_list_str = generate_project_list(sorted(packages)) + with open(variant_dir / "index.html", "w") as f: + f.write(project_list_str) + + for package in packages: + # filter files belonging to this package only + package_files = [f for f in files if f.package_name == package] + package_dir = variant_dir / package + package_dir.mkdir(parents=True, exist_ok=True) + index_str, metadata_str = generate_package_index_and_metadata( + package_files, wheel_base_dir, package_dir + ) + with open(package_dir / "index.html", "w") as f: + f.write(index_str) + with open(package_dir / "metadata.json", "w") as f: + f.write(metadata_str) + + # Generate top-level project list index + project_list_str = generate_project_list(sorted(subdir_names)) + with open(index_base_dir / "index.html", "w") as f: + f.write(project_list_str) + + +if __name__ == 
"__main__": + """ + Arguments: + --version : version string for the current build (e.g., commit hash) + --current-objects : path to JSON file containing current S3 objects listing in this version directory + --output-dir : directory to store generated index files + --alias-to-default : (optional) alias variant name for the default variant + """ + + parser = argparse.ArgumentParser( + description="Process nightly build wheel files to generate indices." + ) + parser.add_argument( + "--version", + type=str, + required=True, + help="Version string for the current build (e.g., commit hash)", + ) + parser.add_argument( + "--current-objects", + type=str, + required=True, + help="Path to JSON file containing current S3 objects listing in this version directory", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to store generated index files", + ) + parser.add_argument( + "--alias-to-default", + type=str, + default=None, + help="Alias variant name for the default variant", + ) + + args = parser.parse_args() + + version = args.version + if "/" in version or "\\" in version: + raise ValueError("Version string must not contain slashes.") + current_objects_path = Path(args.current_objects) + output_dir = Path(args.output_dir) + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + # Read current objects JSON + with open(current_objects_path) as f: + current_objects: dict[str, list[dict[str, Any]]] = json.load(f) + + # current_objects looks like from list_objects_v2 S3 API: + """ + "Contents": [ + { + "Key": "e2f56c309d2a28899c68975a7e104502d56deb8f/vllm-0.11.2.dev363+ge2f56c309-cp38-abi3-manylinux1_x86_64.whl", + "LastModified": "2025-11-28T14:00:32+00:00", + "ETag": "\"37a38339c7cdb61ca737021b968075df-52\"", + "ChecksumAlgorithm": [ + "CRC64NVME" + ], + "ChecksumType": "FULL_OBJECT", + "Size": 435649349, + "StorageClass": "STANDARD" + }, + ... 
+ ] + """ + + # Extract wheel file keys + wheel_files = [] + for item in current_objects.get("Contents", []): + key: str = item["Key"] + if key.endswith(".whl"): + wheel_files.append(key.split("/")[-1]) # only the filename is used + + print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}") + + # Generate index and metadata, assuming wheels and indices are stored as: + # s3://vllm-wheels/{version}/ + # s3://vllm-wheels// + wheel_base_dir = Path(output_dir).parent / version + index_base_dir = Path(output_dir) + + generate_index_and_metadata( + whl_files=wheel_files, + wheel_base_dir=wheel_base_dir, + index_base_dir=index_base_dir, + default_variant=None, + alias_to_default=args.alias_to_default, + ) + print(f"Successfully generated index and metadata in {output_dir}") diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh index d0036f24c8d04..b5f6b2494792f 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh @@ -7,53 +7,51 @@ set -ex # allow to bind to different cores CORE_RANGE=${CORE_RANGE:-0-16} OMP_CORE_RANGE=${OMP_CORE_RANGE:-0-16} -NUMA_NODE=${NUMA_NODE:-0} -export CMAKE_BUILD_PARALLEL_LEVEL=32 +export CMAKE_BUILD_PARALLEL_LEVEL=16 # Setup cleanup remove_docker_container() { set -e; - docker rm -f cpu-test-"$NUMA_NODE" || true; + docker rm -f cpu-test || true; } trap remove_docker_container EXIT remove_docker_container # Try building the docker image -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . +docker build --tag cpu-test --target vllm-test -f docker/Dockerfile.cpu . -# Run the image, setting --shm-size=4g for tensor parallel. 
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +# Run the image +docker run -itd --cpuset-cpus="$CORE_RANGE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test cpu-test function cpu_tests() { set -e - export NUMA_NODE=$2 - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test bash -c " set -e pip list" # offline inference - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test bash -c " set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" # Run kernel tests - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test bash -c " set -e pytest -x -v -s tests/kernels/test_onednn.py pytest -x -v -s tests/kernels/attention/test_cpu_attn.py" # basic online serving - docker exec cpu-test-"$NUMA_NODE" bash -c ' + docker exec cpu-test bash -c ' set -e - VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve meta-llama/Llama-3.2-3B-Instruct --max-model-len 2048 & + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve Qwen/Qwen3-0.6B --max-model-len 2048 & server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 vllm bench serve \ --backend vllm \ --dataset-name random \ - --model meta-llama/Llama-3.2-3B-Instruct \ + --model Qwen/Qwen3-0.6B \ --num-prompts 20 \ --endpoint /v1/completions kill -s SIGTERM $server_pid &' @@ -61,4 +59,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. 
export -f cpu_tests -timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2h bash -c cpu_tests diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 2267718f75ca5..438fe522c8702 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -21,8 +21,8 @@ trap remove_docker_container EXIT remove_docker_container # Try building the docker image -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. 
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index d49f3e2f47cf1..4d163399cfc6c 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -35,7 +35,7 @@ docker run \ echo $ZE_AFFINITY_MASK pip install tblib==3.1.0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager - python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 945c5e48c0090..2eaa91c04086c 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -2,6 +2,28 @@ set -ex +# ======== part 0: setup ======== + +BUCKET="vllm-wheels" +INDICES_OUTPUT_DIR="indices" +DEFAULT_VARIANT_ALIAS="cu129" # align with vLLM_MAIN_CUDA_VERSION in vllm/envs.py +PYTHON=${PYTHON_PROG:=python3} # try to read from env var, otherwise use python3 +SUBPATH=$BUILDKITE_COMMIT 
+S3_COMMIT_PREFIX="s3://$BUCKET/$SUBPATH/" + +# detect if python3.12+ is available +has_new_python=$($PYTHON -c "print(1 if __import__('sys').version_info >= (3,12) else 0)") +if [[ "$has_new_python" -eq 0 ]]; then + # use new python from docker + docker pull python:3-slim + PYTHON="docker run --rm -v $(pwd):/app -w /app python:3-slim python3" +fi + +echo "Using python interpreter: $PYTHON" +echo "Python version: $($PYTHON --version)" + +# ========= part 1: collect, rename & upload the wheel ========== + # Assume wheels are in artifacts/dist/*.whl wheel_files=(artifacts/dist/*.whl) @@ -10,74 +32,69 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}" exit 1 fi - -# Get the single wheel file wheel="${wheel_files[0]}" -# Detect architecture and rename 'linux' to appropriate manylinux version -arch=$(uname -m) -if [[ $arch == "x86_64" ]]; then - manylinux_version="manylinux1" -elif [[ $arch == "aarch64" ]]; then - manylinux_version="manylinux2014" -else - echo "Warning: Unknown architecture $arch, using manylinux1 as default" - manylinux_version="manylinux1" -fi +# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31 +# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels +manylinux_version="manylinux_2_31" # Rename 'linux' to the appropriate manylinux version in the wheel filename +if [[ "$wheel" != *"linux"* ]]; then + echo "Error: Wheel filename does not contain 'linux': $wheel" + exit 1 +fi new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" +echo "Renamed wheel to: $wheel" # Extract the version from the wheel version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2) -echo "Version: $version" +echo "Version in wheel: $version" +pure_version="${version%%+*}" +echo "Pure version (without variant): $pure_version" -normal_wheel="$wheel" # Save the
original wheel filename +# copy wheel to its own bucket +aws s3 cp "$wheel" "$S3_COMMIT_PREFIX" -# If the version contains "dev", rename it to v1.0.0.dev for consistency -if [[ $version == *dev* ]]; then - suffix="${version##*.}" - if [[ $suffix == cu* ]]; then - new_version="1.0.0.dev+${suffix}" - else - new_version="1.0.0.dev" - fi - new_wheel="${wheel/$version/$new_version}" - # use cp to keep both files in the artifacts directory - cp -- "$wheel" "$new_wheel" - wheel="$new_wheel" - version="$new_version" -fi +# ========= part 2: generate and upload indices ========== +# generate indices for all existing wheels in the commit directory +# this script might be run multiple times if there are multiple variants being built +# so we need to guarantee there is little chance for "TOCTOU" issues +# i.e., one process is generating indices while another is uploading a new wheel +# so we need to ensure no time-consuming operations happen below -# Upload the wheel to S3 -python3 .buildkite/generate_index.py --wheel "$normal_wheel" +# list all wheels in the commit directory +echo "Existing wheels on S3:" +aws s3 ls "$S3_COMMIT_PREFIX" +obj_json="objects.json" +aws s3api list-objects-v2 --bucket "$BUCKET" --prefix "$SUBPATH/" --delimiter / --output json > "$obj_json" +mkdir -p "$INDICES_OUTPUT_DIR" -# generate index for this commit -aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" - -if [[ $normal_wheel == *"cu129"* ]]; then - # only upload index.html for cu129 wheels (default wheels) as it - # is available on both x86 and arm64 - aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" - aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" +# call script to generate indices for all existing wheels +# these indices have relative paths that could work as long as it is next to the wheel directory in s3 +# i.e., the wheels are always in
s3://vllm-wheels// +# and indices can be placed in //, or /nightly/, or // +if [[ ! -z "$DEFAULT_VARIANT_ALIAS" ]]; then + alias_arg="--alias-to-default $DEFAULT_VARIANT_ALIAS" else - echo "Skipping index files for non-cu129 wheels" + alias_arg="" fi -# generate index for nightly -aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" -aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" +$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg -if [[ $normal_wheel == *"cu129"* ]]; then - # only upload index.html for cu129 wheels (default wheels) as it - # is available on both x86 and arm64 - aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" -else - echo "Skipping index files for non-cu129 wheels" +# copy indices to // unconditionally +echo "Uploading indices to $S3_COMMIT_PREFIX" +aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "$S3_COMMIT_PREFIX" + +# copy to /nightly/ only if it is on the main branch and not a PR +if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]]; then + echo "Uploading indices to overwrite /nightly/" + aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/" fi -aws s3 cp "$wheel" "s3://vllm-wheels/$version/" -aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" +# copy to // only if it does not have "dev" in the version +if [[ "$version" != *"dev"* ]]; then + echo "Uploading indices to overwrite /$pure_version/" + aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/" +fi diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 4ddf11c0b268f..ee4fdebae5675 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -39,9 +39,9 @@ steps: # if this test fails, it means the nightly torch version is not compatible with some # of the dependencies. 
Please check the error message and add the package to whitelist # in /vllm/tools/pre_commit/generate_nightly_torch_test.py - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking soft_fail: true source_file_dependencies: - requirements/nightly_torch_test.txt @@ -50,9 +50,9 @@ steps: - label: Async Engine, Inputs, Utils, Worker Test # 10min timeout_in_minutes: 15 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/multimodal @@ -61,17 +61,18 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - timeout_in_minutes: 10 - mirror_hardwares: [amdexperimental, amdproduction] +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ - tests/transformers_utils - tests/config no_gpu: true @@ -80,6 +81,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ - pytest -v -s transformers_utils - pytest -v -s config @@ -113,9 +115,9 @@ steps: - pytest -v -s basic_correctness/test_cpu_offload.py - label: Entrypoints Unit Tests # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking timeout_in_minutes: 10 working_dir: "/vllm-workspace/tests" fast_check: true @@ -212,6 +214,7 @@ 
steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py @@ -250,9 +253,9 @@ steps: - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep - label: EPLB Algorithm Test # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" source_file_dependencies: @@ -308,23 +311,20 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 25min - timeout_in_minutes: 40 +- label: Engine Test # 9min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: - vllm/ - tests/engine - - tests/tokenization - tests/test_sequence - tests/test_config - tests/test_logger - tests/test_vllm_port commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenization - label: V1 Test e2e + engine # 30min timeout_in_minutes: 45 @@ -342,9 +342,9 @@ steps: - label: V1 Test entrypoints # 35min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/v1 @@ -392,6 +392,20 @@ steps: commands: - pytest -v -s v1/attention +- 
label: Batch Invariance Tests (H100) # 10min + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + timeout_in_minutes: 25 + gpu: h100 + source_file_dependencies: + - vllm/ + - tests/v1/determinism/ + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pip install pytest-timeout pytest-forked + - pytest -v -s v1/determinism/test_batch_invariance.py + - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py + - label: V1 Test attention (B200) # 10min timeout_in_minutes: 30 gpu: b200 @@ -402,9 +416,9 @@ steps: - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this - label: V1 Test others (CPU) # 5 mins - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/v1 @@ -496,7 +510,7 @@ steps: - label: PyTorch Compilation Unit Tests # 15min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -513,7 +527,7 @@ steps: - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -569,7 +583,7 @@ steps: - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 # grade: Blocking source_file_dependencies: @@ -596,7 +610,7 @@ steps: - label: Kernels MoE Test %N # 40min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 # grade: Blocking source_file_dependencies: @@ -623,6 +637,26 @@ steps: commands: - pytest -v -s kernels/mamba +- label: Kernels DeepGEMM 
Test (H100) # Nvidia-centric +# Not replicating for CUTLAS & CuTe + timeout_in_minutes: 45 + gpu: h100 + num_gpus: 1 + source_file_dependencies: + - tools/install_deepgemm.sh + - vllm/utils/deep_gemm.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization + - tests/kernels/quantization/test_block_fp8.py + - tests/kernels/moe/test_deepgemm.py + - tests/kernels/moe/test_batched_deepgemm.py + - tests/kernels/attention/test_deepgemm_attention.py + commands: + - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm + - pytest -v -s kernels/moe/test_deepgemm.py + - pytest -v -s kernels/moe/test_batched_deepgemm.py + - pytest -v -s kernels/attention/test_deepgemm_attention.py + - label: Model Executor Test # 23min timeout_in_minutes: 35 torch_nightly: true @@ -681,6 +715,7 @@ steps: # we can only upgrade after this is resolved # TODO(jerryzh168): resolve the above comment - uv pip install --system torchao==0.13.0 + - uv pip install --system conch-triton-kernels - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - label: LM Eval Small Models # 15min @@ -900,6 +935,18 @@ steps: commands: - pytest -v -s models/language/pooling_mteb_test +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + - label: Multi-Modal Processor Test # 44min timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] @@ -1056,6 +1103,7 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py - label: Blackwell 
Fusion and Compile Tests # 30 min timeout_in_minutes: 40 @@ -1065,11 +1113,19 @@ steps: - csrc/quantization/fp4/ - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/worker/ + - vllm/v1/cudagraph_dispatcher.py - vllm/compilation/ # can affect pattern matching - vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py + - vllm/model_executor/layers/fused_moe/layer.py + - tests/compile/test_fusion_attn.py + - tests/compile/test_silu_mul_quant_fusion.py + - tests/compile/distributed/test_fusion_all_reduce.py + - tests/compile/distributed/test_fusions_e2e.py + - tests/compile/fullgraph/test_full_graph.py commands: - nvidia-smi - pytest -v -s tests/compile/test_fusion_attn.py @@ -1080,7 +1136,7 @@ steps: # Wrap with quotes to escape yaml - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) - - pytest -v -s tests/compile/distributed/test_full_graph.py::test_fp8_kv_scale_compile + - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell Fusion E2E Tests # 30 min timeout_in_minutes: 40 @@ -1102,7 +1158,7 @@ steps: commands: - nvidia-smi # Run all e2e fusion tests - - pytest -v -s tests/compile/test_fusions_e2e.py + - pytest -v -s tests/compile/distributed/test_fusions_e2e.py - label: ROCm GPT-OSS Eval timeout_in_minutes: 60 @@ -1217,6 +1273,7 @@ steps: - tests/v1/worker/test_worker_memory_snapshot.py commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - 
pytest -v -s entrypoints/llm/test_collective_rpc.py @@ -1252,7 +1309,7 @@ steps: - label: Plugin Tests (2 GPUs) # 40min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1328,7 +1385,7 @@ steps: - label: Weight Loading Multiple GPU Test # 33min timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1428,14 +1485,14 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/compile/distributed/test_async_tp.py + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - - pytest -v -s tests/compile/distributed/test_sequence_parallel.py + - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - HIP_VISIBLE_DEVICES=0,1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - pytest -v -s tests/v1/distributed/test_dbo.py 
##### B200 test ##### @@ -1465,7 +1522,7 @@ steps: - bash .buildkite/scripts/run-prime-rl-test.sh - label: DeepSeek V2-Lite Accuracy - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking timeout_in_minutes: 60 @@ -1476,8 +1533,8 @@ steps: commands: - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 -- label: Qwen3-30B-A3B-FP8-block Accuracy - mirror_hardwares: [amdexperimental] +- label: Qwen3-30B-A3B-FP8-block Accuracy (H100) + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking timeout_in_minutes: 60 @@ -1487,3 +1544,12 @@ steps: working_dir: "/vllm-workspace" commands: - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 + +- label: Qwen3-30B-A3B-FP8-block Accuracy (B200) + timeout_in_minutes: 60 + gpu: b200 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 \ No newline at end of file diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f1cd39ef4f948..52c848c784e53 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,14 +57,15 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - timeout_in_minutes: 10 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min + timeout_in_minutes: 20 source_file_dependencies: - vllm/ - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ - tests/transformers_utils - tests/config no_gpu: true @@ -73,6 +74,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ - pytest -v -s 
transformers_utils - pytest -v -s config @@ -192,6 +194,7 @@ steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py @@ -275,21 +278,18 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 25min - timeout_in_minutes: 40 +- label: Engine Test # 9min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/engine - - tests/tokenization - tests/test_sequence - tests/test_config - tests/test_logger - tests/test_vllm_port commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenization - label: V1 Test e2e + engine # 30min timeout_in_minutes: 45 @@ -390,20 +390,24 @@ steps: - examples/ commands: - pip install tensorizer # for tensorizer test + # for basic + - python3 offline_inference/basic/chat.py - python3 offline_inference/basic/generate.py --model facebook/opt-125m - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 - - python3 offline_inference/basic/chat.py - - python3 offline_inference/prefix_caching.py - - python3 offline_inference/llm_engine_example.py - - python3 offline_inference/audio_language.py --seed 0 - - python3 offline_inference/vision_language.py --seed 0 - - python3 offline_inference/vision_language_pooling.py --seed 0 - - python3 offline_inference/vision_language_multi_image.py --seed 0 - - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize 
--serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py + # for multi-modal models + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 @@ -631,6 +635,7 @@ steps: # we can only upgrade after this is resolved # TODO(jerryzh168): resolve the above comment - uv pip install --system torchao==0.13.0 --index-url 
https://download.pytorch.org/whl/cu129 + - uv pip install --system conch-triton-kernels - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - label: LM Eval Small Models # 53min @@ -818,14 +823,24 @@ steps: commands: - pytest -v -s models/language/pooling_mteb_test -- label: Multi-Modal Processor Test # 44min +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + +- label: Multi-Modal Processor Test timeout_in_minutes: 60 source_file_dependencies: - vllm/ - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s models/multimodal/processing + - pytest -v -s models/multimodal/processing/test_tensor_schema.py - label: Multi-Modal Models Test (Standard) # 60min timeout_in_minutes: 80 @@ -902,11 +917,12 @@ steps: - label: Transformers Nightly Models Test working_dir: "/vllm-workspace/" optional: true + soft_fail: true commands: - pip install --upgrade git+https://github.com/huggingface/transformers - - pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)' + - pytest -v -s tests/models/test_initialization.py - pytest -v -s tests/models/test_transformers.py - # - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/processing/ - pytest -v -s tests/models/multimodal/test_mapping.py - python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl @@ -1116,6 +1132,7 @@ steps: # https://github.com/NVIDIA/nccl/issues/1838 - export NCCL_CUMEM_HOST_ENABLE=0 - TP_SIZE=1 
DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py @@ -1299,11 +1316,11 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - pytest -v -s tests/compile/distributed/test_async_tp.py + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py - - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - - pytest -v -s tests/distributed/test_sequence_parallel.py + - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - pytest -v -s tests/v1/distributed/test_dbo.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3247408e1163e..d6447649cd89a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -146,9 +146,10 @@ mkdocs.yaml @hmellor /requirements/kv_connectors.txt @NickLucche # Pooling models -/examples/*/pooling/ @noooop +/examples/pooling @noooop /tests/models/*/pooling* @noooop /tests/entrypoints/pooling @noooop +/vllm/entrypoints/pooling @noooop /vllm/config/pooler.py @noooop /vllm/pooling_params.py @noooop /vllm/model_executor/layers/pooler.py @noooop diff --git a/.github/workflows/cleanup_pr_body.yml 
b/.github/workflows/cleanup_pr_body.yml index c3e132a536a42..56fbe5ca704a1 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -13,10 +13,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: '3.12' diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml index 7d565ef9f2e45..629966b959330 100644 --- a/.github/workflows/issue_autolabel.yml +++ b/.github/workflows/issue_autolabel.yml @@ -105,6 +105,31 @@ jobs: } ], }, + cpu: { + // Keyword search - matches whole words only (with word boundaries) + keywords: [ + { + term: "CPU Backend", + searchIn: "title" + }, + { + term: "x86", + searchIn: "title" + }, + { + term: "ARM", + searchIn: "title" + }, + { + term: "Apple Silicon", + searchIn: "title" + }, + { + term: "IBM Z", + searchIn: "title" + }, + ], + }, // Add more label configurations here as needed // example: { // keywords: [...], diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml index a183033c9adde..3a12c4b3a8300 100644 --- a/.github/workflows/macos-smoke-test.yml +++ b/.github/workflows/macos-smoke-test.yml @@ -12,7 +12,7 @@ jobs: timeout-minutes: 30 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v7 with: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index e21d13b8161f3..a03b979ad761d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -16,8 +16,8 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: 
actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/CMakeLists.txt b/CMakeLists.txt index a4cf51d17e982..e09972fe71995 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -136,7 +136,7 @@ elseif(HIP_FOUND) # ROCm 5.X and 6.X if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) + Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM}) message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " "expected for ROCm build, saw ${Torch_VERSION} instead.") endif() @@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + + # marlin arches for fp16 output + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) + cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. 
RTX 50x0) + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + if (MARLIN_ARCHS) # @@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MARLIN_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") - message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} - OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=$PYTHONPATH - ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log @@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "\nCheck the log for details: " "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") else() - set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} - CACHE STRING "Last run Marlin generate script hash" FORCE) + set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + CACHE STRING "Last run Marlin generate script hash and arch" FORCE) message(STATUS "Marlin generation completed successfully.") endif() else() message(STATUS "Marlin generation script has not changed, skipping generation.") endif() - file(GLOB 
MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" CUDA_ARCHS "${MARLIN_ARCHS}") @@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() - list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_BF16_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC}) + + if (MARLIN_FP8_ARCHS) + file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC}) + endif() + set(MARLIN_SRCS "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( @@ -604,12 +637,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - 
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() message(STATUS "Not building NVFP4 as no compatible archs were found.") @@ -938,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") - # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + # moe marlin arches + # note that we always set `use_atomic_add=False` for moe marlin now, + # so we don't need 9.0 for bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + # moe marlin arches for fp8 input + # - sm80 doesn't support fp8 computation + # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. 
RTX 50x0) + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # @@ -949,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MOE_MARLIN_GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") - message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") - message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") + message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}") - if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} - OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=$PYTHONPATH - ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE moe_marlin_generation_result OUTPUT_VARIABLE moe_marlin_generation_output OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log @@ -971,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "\nCheck the log for details: " "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") else() - set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} + set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} CACHE STRING "Last run Marlin MOE generate script hash" FORCE) message(STATUS "Marlin MOE generation completed successfully.") endif() @@ -979,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL 
"CUDA") message(STATUS "Marlin MOE generation script has not changed, skipping generation.") endif() - file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") + file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu") + list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu") set_gencode_flags_for_srcs( - SRCS "${MOE_WNAA16_MARLIN_SRC}" + SRCS "${MARLIN_MOE_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) - set_source_files_properties(${MOE_WNAA16_MARLIN_SRC} + set_source_files_properties(${MARLIN_MOE_SRC} PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC}) - list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) + if (MARLIN_MOE_FP8_ARCHS) + file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_FP8_SRC}" + CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + set_source_files_properties(${MARLIN_MOE_FP8_SRC} + PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false") + endif() + list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC}) + endif() message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() diff --git a/README.md b/README.md index 033e1035d8916..abbb63158f166 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio *Latest News* 🔥 +- [2025/11] We hosted [vLLM Bangkok Meetup](https://luma.com/v0f647nv). We explored vLLM and LMCache inference and low-resource language adaptation with speakers from Embedded LLM, AMD, and Red Hat. Please find the meetup slides [here](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing). 
- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link). - [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6). diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md index d1bdb4c43f10b..9a9600e08dafe 100644 --- a/benchmarks/auto_tune/README.md +++ b/benchmarks/auto_tune/README.md @@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number ``` -#### 2. Maximize Throughput with a Latency Requirement +### 2. Maximize Throughput with a Latency Requirement - **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms. - **Configuration**: @@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=500 ``` -#### 3. Maximize Throughput with Prefix Caching and Latency Requirements +### 3. Maximize Throughput with Prefix Caching and Latency Requirements - **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms. 
- **Configuration**: diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 4021fede72153..d69d74ca61f54 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -620,7 +620,7 @@ def get_tokenizer( kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: - from vllm.transformers_utils.tokenizer import MistralTokenizer + from vllm.tokenizers import MistralTokenizer except ImportError as e: raise ImportError( "MistralTokenizer requires vllm package.\n" diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index dedb564fffac8..cac401456b62a 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -108,7 +108,10 @@ def benchmark_batched_propose(args): device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 28fc383a318dd..e6391134ff932 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -40,7 +40,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.utils.argparse_utils import FlexibleArgumentParser try: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer except ImportError: from backend_request_func import get_tokenizer diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 55001cf3722a0..df122b4c5e8db 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -46,7 +46,7 @@ from 
tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase try: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer except ImportError: from backend_request_func import get_tokenizer diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 8787724d77cfb..ac78c019a59e5 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: b_q_weight=w_q, b_bias=None, b_scales=w_s, + a_scales=None, global_scale=None, b_zeros=w_zp, g_idx=g_idx, diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 12ca9214b1f95..48d790aec9e07 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -263,7 +263,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -273,7 +273,7 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, 
False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/cmake/utils.cmake b/cmake/utils.cmake index ca0062ba4fabe..5047c354ff7d2 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -495,7 +495,13 @@ function (define_extension_target MOD_NAME) set(SOABI_KEYWORD "") endif() - if (ARG_USE_SABI) + run_python(IS_FREETHREADED_PYTHON + "import sysconfig; print(1 if sysconfig.get_config_var(\"Py_GIL_DISABLED\") else 0)" + "Failed to determine whether interpreter is free-threaded") + + # Free-threaded Python doesn't yet support the stable ABI (see PEP 803/809), + # so avoid using the stable ABI under free-threading only. + if (ARG_USE_SABI AND NOT IS_FREETHREADED_PYTHON) Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}") else() Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}") diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 229d9862fb670..27d1e990c611e 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel( scalar_t* output, float* output_lse, const scalar_t* prefix_output, const float* prefix_lse, const scalar_t* suffix_output, const float* suffix_lse, const uint num_tokens, const uint num_heads, - const uint head_size) { + const uint head_size, const uint prefix_head_stride, + const uint output_head_stride) { using pack_128b_t = uint4; const uint pack_size = 16 / sizeof(scalar_t); const uint threads_per_head = head_size / pack_size; @@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel( const uint head_idx = token_head_idx % num_heads; const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc. 
- const uint head_offset = - token_idx * num_heads * head_size + head_idx * head_size; - const scalar_t* prefix_head_ptr = prefix_output + head_offset; - const scalar_t* suffix_head_ptr = suffix_output + head_offset; - scalar_t* output_head_ptr = output + head_offset; + const uint src_head_offset = token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride; + const uint dst_head_offset = token_idx * num_heads * output_head_stride + + head_idx * output_head_stride; + const scalar_t* prefix_head_ptr = prefix_output + src_head_offset; + const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; + scalar_t* output_head_ptr = output + dst_head_offset; float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; @@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel( reinterpret_cast(prefix_lse.data_ptr()), \ reinterpret_cast(suffix_output.data_ptr()), \ reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ - num_heads, head_size); \ + num_heads, head_size, prefix_head_stride, output_head_stride); \ } /*@brief Merges the attention states from prefix and suffix @@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output, const uint num_tokens = output.size(0); const uint num_heads = output.size(1); const uint head_size = output.size(2); + const uint prefix_head_stride = prefix_output.stride(1); + const uint output_head_stride = output.stride(1); const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); - TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1, - "output heads must be contiguous in memory"); - TORCH_CHECK( - prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1, - "prefix_output heads must be contiguous in memory"); - TORCH_CHECK( - suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1, - "suffix_output heads must be 
contiguous in memory"); float* output_lse_ptr = nullptr; if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 5199ba2af024f..3dacfc7b2b7a3 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -51,12 +51,13 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { if (node_id != -1) { node_ids.insert(node_id); } - TORCH_WARN(node_id == mem_node_id, "CPU ", cpu_id, " is on NUMA node ", - node_id, ", but CPU ", omp_cpu_ids.front(), - " is on NUMA node ", mem_node_id, - ". All CPUs should be on the same NUMA node for optimal " - "performance. Memory will be bound to NUMA node ", - mem_node_id, "."); + if (node_id != mem_node_id) { + TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ", + omp_cpu_ids.front(), " is on NUMA node ", mem_node_id, + ". All CPUs should be on the same NUMA node for optimal " + "performance. Memory will be bound to NUMA node ", + mem_node_id, "."); + } } // Concatenate all node_ids into a single comma-separated string if (!node_ids.empty()) { diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index df47bb8dd1d7d..58dc402016881 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -93,16 +93,16 @@ torch::Tensor dynamic_4bit_int_moe_cpu( } auto Y_all = at::empty({offsets[E], H}, x_c.options()); - at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) { c10::InferenceMode guard; - for (int64_t e = e_begin; e < e_end; ++e) { - const int64_t te = counts[e]; - if (te == 0) { + for (int64_t e = 0; e < E; ++e) { + int64_t start = std::max(offsets[e], idx_begin); + int64_t end = std::min(offsets[e + 1], idx_end); + int64_t te = end - start; + if (te <= 0) { continue; } - const int64_t start = offsets[e]; - auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); 
auto w13_e = w13_packed.select(/*dim=*/0, e); diff --git a/csrc/moe/marlin_moe_wna16/.gitignore b/csrc/moe/marlin_moe_wna16/.gitignore index 77088552b85b4..ba805f9250ece 100644 --- a/csrc/moe/marlin_moe_wna16/.gitignore +++ b/csrc/moe/marlin_moe_wna16/.gitignore @@ -1 +1,2 @@ -kernel_*.cu \ No newline at end of file +sm*_kernel_*.cu +kernel_selector.h diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index be5b68cc53e6f..88f1055337fd5 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -4,134 +4,282 @@ import glob import itertools import os import subprocess +import sys import jinja2 -FILE_HEAD = """ -// auto generated by generate.py -// clang-format off +ARCHS = [] +SUPPORT_FP8 = False +for arch in sys.argv[1].split(","): + arch = arch[: arch.index(".") + 2].replace(".", "") + arch = int(arch) + # only SM89 and SM120 fully support + # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM90 and SM100 can use this PTX, but it’s simulated + # with FP16 MMA, so it cannot achieve any acceleration. + if arch in [89, 120]: + SUPPORT_FP8 = True +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py +// clang-format off +""".lstrip() + +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { -""".strip() +""" +) TEMPLATE = ( "template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " + "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" + "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) -# int8 with zero point case (vllm::kU8) is also supported, -# we don't add it to reduce wheel size. 
-SCALAR_TYPES = [ - "vllm::kU4", - "vllm::kU4B8", - "vllm::kU8B128", - "vllm::kFE4M3fn", - "vllm::kFE2M1f", -] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] -# group_blocks: -# = 0 : act order case -# = -1 : channelwise quantization -# > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] -DTYPES = ["fp16", "bf16"] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # AWQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + 
"b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + def generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + result_dict = {} + + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": + continue + if "16" in a_type and "16" in c_type and a_type != c_type: + continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue + + config = { + "threads": threads, 
+ "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": "pipe_stages", + "group_blocks": group_blocks, + "is_zp_float": "false", + } + + result_dict[(a_type, b_type, c_type)].append(config) + + kernel_selector_str = FILE_HEAD_COMMENT + + for (a_type, b_type, c_type), config_list in result_dict.items(): all_template_str_list = [] - - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS - ): - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", - "vllm::kU8B128", - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue - - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: - continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: - continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: - continue - - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] - - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" - - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == 
"bf16": - s_type = "vllm::kBFloat16" - + for config in config_list: + s_type = config["s_type"] template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=False, + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + all_template_str_list.append(template_str) + + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) + + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " + + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" ) - all_template_str_list.append(template_str) + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + 
"\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " TORCH_CHECK(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + if __name__ == "__main__": remove_old_kernels() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 6190f7ee21ece..57f5a17932d44 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -11,8 +11,9 @@ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ b_bias_ptr, \ + const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \ @@ -20,12 +21,13 @@ const float *__restrict__ topk_weights_ptr, int top_k, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ - bool use_fp32_reduce, int max_shared_mem + bool use_fp32_reduce namespace MARLIN_NAMESPACE_NAME { -template shared // fetch pipeline - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? 
+ const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -76,8 +77,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce, // whether to use fp32 global reduce - int max_shared_mem) {} + bool use_fp32_reduce // whether to use fp32 global reduce +) {} } // namespace MARLIN_NAMESPACE_NAME @@ -85,65 +86,148 @@ __global__ void Marlin( // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { +template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), 
"f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : 
"r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } -template +template __device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); const uint32_t* b2 = reinterpret_cast(&frag_b2); float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm 
volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200 + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + #else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + 
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + #endif + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -template -__device__ inline void ldsm(typename ScalarType::FragA& frag_a, +template +__device__ inline void ldsm(typename MarlinScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); @@ -167,47 +251,54 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +template +__device__ inline void scale(typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s = MarlinScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -template +template __device__ inline void scale_and_sub( - typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s2 = ScalarType::num2num2(s); - scalar_t2 zp2 = ScalarType::num2num2(zp); + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t s, + typename MarlinScalarType::scalar_t zp) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s2 = MarlinScalarType::num2num2(s); + scalar_t2 zp2 = MarlinScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); } -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); +template +__device__ inline void sub_zp( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t2& frag_zp, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 zp = MarlinScalarType::num2num2( + reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], 
zp); } // Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; +template +__device__ inline void scale4( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s_1, + typename MarlinScalarType::FragS& frag_s_2, + typename MarlinScalarType::FragS& frag_s_3, + typename MarlinScalarType::FragS& frag_s_4, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; @@ -221,12 +312,13 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, } // Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { +template +__device__ inline void scale_float( + float* c, typename MarlinScalarType::FragS& s) { + using scalar_t = typename MarlinScalarType::scalar_t; scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + c[0] = __fmul_rn(c[0], MarlinScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], MarlinScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. 
@@ -278,9 +370,10 @@ __device__ inline void wait_negative_and_add(int* lock) { __syncthreads(); } -template ; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. + if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; + #endif + + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + using Adtype = MarlinScalarType; + using Cdtype = MarlinScalarType; + + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + using scalar_32bit_t = typename MarlinScalarType::scalar_32bit_t; + + using c_scalar_t = typename MarlinScalarType::scalar_t; + using c_scalar_t2 = typename MarlinScalarType::scalar_t2; + + using FragA = typename MarlinScalarType::FragA; + using FragB = typename MarlinScalarType::FragB; + using FragC = typename MarlinScalarType::FragC; + using FragS = typename MarlinScalarType::FragS; + using FragZP = typename MarlinScalarType::FragZP; extern __shared__ int4 sh[]; - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id); + static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id); + static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); } else if constexpr (std::is_same::value) { @@ -355,34 +472,37 
@@ __global__ void Marlin( static_assert(s_type == vllm::kFloat16); } - constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_a_8bit = a_type.size_bits() == 8; + if constexpr (!is_a_8bit) { + static_assert(std::is_same::value); + } + constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8; + constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || - w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || + is_a_8bit || b_type == vllm::kFE4M3fn || + b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == vllm::kU8); + has_zp && !is_zp_float && !(b_type == vllm::kU8); - scalar_t2 global_scale; + c_scalar_t2 global_scale; constexpr bool has_act_order = group_blocks == 0; - constexpr int pack_factor = 32 / w_type.size_bits(); + constexpr int pack_factor = 32 / b_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); - constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; const int scales_expert_stride = - prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); + prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? 
prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); const int b_bias_expert_stride = prob_n / 8; // parallel: num valid moe blocks - int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; int parallel = num_tokens_past_padded / moe_block_size; int num_valid_blocks = parallel; if (is_ep) { @@ -395,7 +515,23 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + // we use DP + two-tile SK here + // part1: DP + // part2: two-tile SK + // see https://github.com/vllm-project/vllm/pull/24722 for more details + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x; + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { @@ -407,14 +543,15 @@ __global__ void Marlin( } } - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = + k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; int par_id = 0; int block_id = -1; @@ -422,85 
+559,89 @@ __global__ void Marlin( int old_expert_id = 0; int64_t B_expert_off = 0; - int4* sh_block_sorted_ids_int4 = sh; + float* sh_a_s = reinterpret_cast(sh); + int4* sh_block_sorted_ids_int4 = sh + (is_a_8bit ? (4 * thread_m_blocks) : 0); int4* sh_rd_block_sorted_ids_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4; int4* sh_block_topk_weights_int4 = sh_rd_block_sorted_ids_int4 + moe_block_size / 4; // sh_block_topk_weights_int4 only need (moe_block_size / 4); // but we pad to align to 256 bytes - int4* sh_new = - sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 2; int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); int32_t* sh_rd_block_sorted_ids = reinterpret_cast(sh_rd_block_sorted_ids_int4); - scalar_t2* sh_block_topk_weights = - reinterpret_cast(sh_block_topk_weights_int4); + c_scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); int32_t block_num_valid_tokens = 0; int32_t locks_off = 0; // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) { - // when parallel * n_tiles >= sms + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms // then there are at most $sms$ conflict tile blocks locks_off = blockIdx.x; } else { locks_off = (iters * blockIdx.x) / k_tiles - 1; } + int prob_m_top_k = prob_m * top_k; // read moe block data given block_id // block_sorted_ids / block_num_valid_tokens / block_topk_weights auto read_moe_block_data = [&](int block_id) { block_num_valid_tokens = moe_block_size; + + cp_async4_pred(sh_block_sorted_ids_int4 + threadIdx.x, + reinterpret_cast(sorted_token_ids_ptr) + + (block_id * moe_block_size / 4 + threadIdx.x), + threadIdx.x < moe_block_size / 4); + + cp_async_fence(); + 
cp_async_wait<0>(); + + __syncthreads(); + + if (threadIdx.x >= threads - 32) { + constexpr int size_per_thread = div_ceil(moe_block_size, 32); + int lane_id = threadIdx.x - (threads - 32); + + int local_count = 0; #pragma unroll - for (int i = 0; i < moe_block_size / 4; i++) { - int4 sorted_token_ids_int4 = reinterpret_cast( - sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; - int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); - #pragma unroll - for (int j = 0; j < 4; j++) { - if (sorted_token_ids[j] >= prob_m * top_k) { - block_num_valid_tokens = i * 4 + j; - break; + for (int i = 0; i < size_per_thread; i++) { + int j = lane_id * size_per_thread + i; + if (j < moe_block_size) { + int idx = sh_block_sorted_ids[j]; + if (idx < prob_m_top_k) local_count++; } } - if (block_num_valid_tokens != moe_block_size) break; + + block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); + + if (lane_id == 0) + reinterpret_cast(sh_new)[0] = block_num_valid_tokens; + } + + if (threadIdx.x < moe_block_size) { + int idx = sh_block_sorted_ids[threadIdx.x]; + sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k; + + if (mul_topk_weights) { + idx = idx < prob_m_top_k ? 
idx : 0; + c_scalar_t2 topk_weight_val = + Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx])); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + topk_weight_val = __hmul2(topk_weight_val, global_scale); + } + sh_block_topk_weights[threadIdx.x] = topk_weight_val; + } } __syncthreads(); - int tid4 = threadIdx.x / 4; - if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { - sh_block_sorted_ids_int4[tid4] = reinterpret_cast( - sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; - #pragma unroll - for (int i = 0; i < 4; i++) - sh_rd_block_sorted_ids[tid4 * 4 + i] = - sh_block_sorted_ids[tid4 * 4 + i] / top_k; - - if (mul_topk_weights) { - #pragma unroll - for (int i = 0; i < 4; i++) { - int idx = tid4 * 4 + i; - idx = idx < block_num_valid_tokens ? idx : 0; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - sh_block_topk_weights[idx] = __hmul2( - global_scale, Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = Dtype::num2num2( - Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); - } - } - } - } + block_num_valid_tokens = reinterpret_cast(sh_new)[0]; __syncthreads(); }; @@ -511,9 +652,8 @@ __global__ void Marlin( old_expert_id = expert_id; if (num_invalid_blocks > 0) { - int skip_count = block_id == -1 ? 
par_id : 0; - block_id++; - for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + int skip_count = par_id; + for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) { expert_id = expert_ids_ptr[i]; if (expert_id != -1) { if (skip_count == 0) { @@ -528,9 +668,9 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - uint16_t val = scale2_ptr[expert_id]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + uint16_t val = global_scale_ptr[expert_id]; + global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); } B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); @@ -550,10 +690,11 @@ __global__ void Marlin( // Compute all information about the current slice which is required for // synchronization. - auto init_slice = [&](bool first_init = false) { + bool first_init = true; + auto init_part2_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; @@ -571,7 +712,7 @@ __global__ void Marlin( if (col_off > 0) slice_idx--; } } - if (parallel * n_tiles >= gridDim.x) { + if (part2_mn_tiles >= gridDim.x) { if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } @@ -605,25 +746,61 @@ __global__ void Marlin( par_id++; update_next_moe_block_data(); } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + cp_async1_ca_pred(&sh_a_s[threadIdx.x], + &a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]], + threadIdx.x < block_num_valid_tokens); + } }; - update_next_moe_block_data(); - init_slice(true); + auto init_part1_slice = [&]() { + if 
(part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + update_next_moe_block_data(); + if (is_a_8bit) { + __syncthreads(); + cp_async1_ca_pred(&sh_a_s[threadIdx.x], + &a_scales_ptr[sh_rd_block_sorted_ids[threadIdx.x]], + threadIdx.x < block_num_valid_tokens); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + update_next_moe_block_data(); + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); // A sizes/strides // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; + int a_gl_stride = prob_k / (is_a_8bit ? 16 : 8); // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 
16 : 8); // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile @@ -632,24 +809,25 @@ __global__ void Marlin( constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = + ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_stage = + b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = + 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? 
thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -662,7 +840,8 @@ __global__ void Marlin( constexpr int act_s_max_num_groups = 32; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; + + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides @@ -677,7 +856,6 @@ __global__ void Marlin( // Global A read index of current thread. int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; - // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -685,17 +863,22 @@ __global__ void Marlin( int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 
2 : 1)); - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } + + b_gl_rd += B_expert_off + b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; @@ -706,58 +889,54 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / - (w_type == vllm::kFE2M1f ? 
2 : 1) + + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; if constexpr (has_zp) { if constexpr (group_blocks == -1) { zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { + } else if constexpr (group_blocks >= thread_k_blocks) { zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; } } auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); } else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; int bias_sh_rd; if constexpr (m_block_size_8) { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; } else { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; } @@ -773,12 +952,16 @@ __global__ void Marlin( if constexpr (has_zp) { if constexpr (is_zp_float) { if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + zp_sh_rd = + 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % tb_n_warps / 2) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } else { zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } } @@ -805,18 +988,13 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; // Shared memory storage for global fetch pipelines. 
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; @@ -845,19 +1023,12 @@ __global__ void Marlin( static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); int4* sh_a = sh_s + sh_s_size; - constexpr int shm_size_used = moe_block_size + - stages * (g_idx_stage + zp_sh_stage) + - sh_s_size + sh_b_red_bias_size; - - // all remaining shared memory is used to cache A (input) - // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` - int sh_a_max_row = - ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; FragS frag_s[2][4]; // No act-order FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order @@ -865,6 +1036,24 @@ __global__ void Marlin( FragZP frag_zp; // Zero-points in fp16 FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + if constexpr (is_a_8bit && group_blocks != -1) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + // Zero accumulators. auto zero_accums = [&]() { #pragma unroll @@ -908,43 +1097,36 @@ __global__ void Marlin( } } }; - // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. 
- bool should_load_a = true; - int max_num_stage_groups = - ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; - max_num_stage_groups = max(max_num_stage_groups, 1); - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, - int pipe_a = 0) { + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - if (should_load_a) { - int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; - int64_t sorted_row = 0; - if (!m_block_size_8 || row < 8) - sorted_row = sh_rd_block_sorted_ids[row]; - int64_t true_idx = - sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], - row < block_num_valid_tokens); - } + for (int i = 0; i < a_sh_wr_iters; i++) { + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = + sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], - B_ptr[i] + j + B_expert_off); - } + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = + b_gl_rd + (i % count) * threads + + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); - B_ptr[i] += b_gl_rd_delta_o; + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); } + b_gl_rd += b_gl_rd_delta_o; + if constexpr (has_act_order) { // Fetch g_idx thread-block 
portion int full_pipe = a_off; @@ -964,44 +1146,24 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -1035,18 +1197,18 @@ __global__ void Marlin( // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. 
- auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { - int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm( + ldsm( frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -1070,53 +1232,54 @@ __global__ void Marlin( auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1); if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = - reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } - 
} else { + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / tb_n_warps; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = - k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != vllm::kFE2M1f.id()) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + - k % 2]; + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } } } @@ -1137,18 +1300,15 @@ __global__ void Marlin( cur_k = 0; // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + cur_k += k % b_sh_wr_iters; // Determine "position" inside the thread-block (based on warp and // thread-id) auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; + cur_k += warp_row * 16 * b_sh_wr_iters; auto 
th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix @@ -1203,18 +1363,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 || is_a_8bit) { #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } - } else if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = @@ -1223,21 +1381,11 @@ __global__ void Marlin( } } else { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / tb_n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 
2 : 1); int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1256,29 +1404,18 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + - zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; } - } else { + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1289,33 +1426,46 @@ __global__ void Marlin( } }; - auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) { + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) return; int k2 = k % 2; + constexpr int g = + group_blocks > 0 ? 
div_ceil(group_blocks, thread_k_blocks) : 1; const bool is_new_zp = - ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == 0) || + ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && + (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; int zp_quant_0, zp_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (b_type.size_bits() == 4) { zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = zp_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = frag_qzp[k2][1]; } - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); } } if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { @@ -1325,14 +1475,14 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales( - s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( - s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1343,61 +1493,168 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (b_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; - } else if constexpr (w_type.size_bits() == 4) { + } else if constexpr 
(b_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); } // Apply scale to frag_b0 - if constexpr (has_act_order) { + if constexpr (has_act_order && !is_a_8bit) { static_assert(group_blocks != -1); - scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && - group_blocks == -1) { + group_blocks == -1 && !is_a_8bit) { int idx = (threadIdx.x / 4) % 2; - scalar_t2 s2 = Dtype::nums2num2( + scalar_t2 s2 = Adtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + 
scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && + !is_a_8bit) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } else if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; + #pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = + reinterpret_cast(&frag_b_quant[k2][j])[0]; + 
reinterpret_cast(&frag_b)[1] = + reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma(frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * + scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * + scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || + s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t* s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = + *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = 
reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } } } } @@ -1411,7 +1668,8 @@ __global__ void Marlin( constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_stride = + b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); @@ -1426,7 +1684,8 @@ __global__ void Marlin( for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { @@ -1435,24 +1694,26 @@ __global__ void Marlin( float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 
2 : 1)) { float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); @@ -1468,13 +1729,13 @@ __global__ void Marlin( // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; if (!is_th_active) { return; } - int c_gl_stride = prob_n / 8; + int c_gl_stride = prob_n / 8 * (is_a_8bit ? 2 : 1); int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr; @@ -1485,7 +1746,7 @@ __global__ void Marlin( } else { c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1); } constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; @@ -1504,7 +1765,13 @@ __global__ void Marlin( if (c_idx / c_gl_stride < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + if constexpr (is_a_8bit) { + int2* sh_red_int2 = reinterpret_cast(sh_red); + int2* c_int2 = reinterpret_cast(C); + sh_red_int2[c_sh_wr + c_sh_wr_delta * i] = c_int2[true_idx]; + } else { + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } } } } @@ -1512,29 +1779,37 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_scalar_t* c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = + reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + + delta] += Cdtype::num2float(c_red_f16[j]); } } if (!last) { - int4 c; + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 
2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + + delta]); } int c_idx; @@ -1547,7 +1822,12 @@ __global__ void Marlin( if (c_idx / c_gl_stride < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - C[true_idx] = c; + if constexpr (is_a_8bit) { + int2* c_int2 = reinterpret_cast(C); + c_int2[true_idx] = *reinterpret_cast(c_f16); + } else { + C[true_idx] = *reinterpret_cast(c_f16); + } } } } @@ -1561,10 +1841,10 @@ __global__ void Marlin( constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; constexpr int th_size = num_floats * sizeof(float) / 16; int c_cur_offset = locks_off * c_size; @@ -1632,7 +1912,7 @@ __global__ void Marlin( } else { c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); + c_sh_wr += (is_a_8bit ? 
16 : 32) * (threadIdx.x / 32); } int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + @@ -1641,49 +1921,49 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + c_scalar_t2 res = + Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && + b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - scalar_t2 tmp_scale = s[0]; + c_scalar_t2 tmp_scale = s[0]; if constexpr (m_block_size_8) { - tmp_scale = Dtype::num2num2( + tmp_scale = Cdtype::num2num2( reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); } res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if (!mul_topk_weights) { res = __hmul2(res, global_scale); } } if (has_bias && last) { - scalar_t2 tmp_bias = b_bias[0]; + c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { - tmp_bias = Dtype::num2num2( + tmp_bias = Cdtype::num2num2( reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); } res = __hadd2(res, tmp_bias); } if constexpr (m_block_size_8) { - ((scalar_t*)sh_red)[idx] = res.x; - ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + ((c_scalar_t*)sh_red)[idx] = res.x; + ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; } else { - ((scalar_t2*)sh_red)[idx] = res; + ((c_scalar_t2*)sh_red)[idx] = res; } }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j 
= 0; j < 4; j++) { + for (int j = 0; j < (is_a_8bit ? 2 : 4); j++) { if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], @@ -1721,24 +2001,26 @@ __global__ void Marlin( if (row < block_num_valid_tokens) { int64_t sorted_row = sh_block_sorted_ids[row]; int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; - scalar_t2 topk_weight_score; + c_scalar_t2 topk_weight_score; if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; if (use_atomic_add && slice_count > 1 || mul_topk_weights) { - scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); + c_scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + c_scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + if (mul_topk_weights) { #pragma unroll - for (int a = 0; a < 4; a++) { - scalar_t2 res = sh_red_half2[a]; - if (mul_topk_weights) { - res = __hmul2(res, topk_weight_score); + for (int a = 0; a < 4; a++) { + sh_red_half2[a] = __hmul2(sh_red_half2[a], topk_weight_score); } + } - if (use_atomic_add && slice_count > 1) { - atomicAdd(&C_half2[a], res); - } else { - C_half2[a] = res; - }; + if (use_atomic_add && slice_count > 1) { + #pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[true_idx] = *reinterpret_cast(sh_red_half2); } } else { C[true_idx] = sh_red[c_sh_rd]; @@ -1772,7 +2054,7 @@ __global__ void Marlin( } } } - fetch_to_shared(i, i, i < slice_iters, i); + fetch_to_shared(i, i, i < slice_iters); } zero_accums(); @@ -1797,73 +2079,100 @@ __global__ void Marlin( // have even length meaning that the next iteration will always start at // index 0. 
- for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; - stage_group_id++) { #pragma unroll - for (int pipe = 0; pipe < stages;) { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - int idx = - (pipe >= stages && stage_group_id == max_num_stage_groups - 1) - ? (pipe - stages) - : (pipe + stage_group_id * stages); - fetch_to_registers(k + 1, pipe % stages, idx); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) - ? (pipe - 1) - : (pipe + (stage_group_id + 1) * stages - 1); - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages, idx); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd_col += a_gl_rd_delta_o * stages; - - if constexpr (has_act_order) { - slice_k_start += tb_k * stages; - - if (slice_k_start < prob_k) { - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, - last_group_id); - __syncthreads(); - } + + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 
1 : 0)); + } else { + static_assert(group_blocks != 0 && group_blocks != 1); + matmul_a8(k); } } + slice_iters--; if (slice_iters == 0) { break; } } + a_gl_rd_col += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } + } + } + // Process results and, if necessary, proceed to the next column slice. // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { + if constexpr (is_a_8bit) { + float frag_a_s[2 * thread_m_blocks]; + + for (int i = 0; i < 2 * thread_m_blocks; i++) + frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4]; + + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][0][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][0][g] = c_val * s_val; + } + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][1][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][1][g] = c_val * s_val; + } + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1 && (has_zp 
&& dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1881,20 +2190,27 @@ __global__ void Marlin( } if constexpr (!has_act_order && group_blocks == -1 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) { + if constexpr (is_a_8bit) { cp_async_wait<0>(); __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + } + } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < tb_n_warps) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; if constexpr (m_block_size_8) { int idx = (threadIdx.x / 4) % 2; - scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + c_scalar_t2* frag_s_half2 = + reinterpret_cast(frag_s); #pragma unroll for (int i = 0; i < 8; i++) { - frag_s_half2[i] = Dtype::num2num2( - reinterpret_cast(&frag_s_half2[i])[idx]); + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); } } } @@ -1904,26 +2220,48 @@ __global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + + 
#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && + b_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); if constexpr (!m_block_size_8) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } @@ -1947,7 +2285,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; - reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + if constexpr (!is_a_8bit) + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; __syncthreads(); } @@ -1956,37 +2295,22 @@ __global__ void Marlin( if (last || use_atomic_add) // only the last block in a slice actually writes the result write_result(last); - int old_slice_row = slice_row; slice_row = 0; - slice_col_par++; - slice_col++; + if (!in_part2) { + slice_col_par += gridDim.x; + } else { + slice_col_par++; + slice_col++; + } is_first_matmul_in_slice = true; init_slice(); - // Should we load A matrix in next slice? 
- // `slice_col == 0`: when move to a new moe block - // `old_slice_row > 0`: - // when the last slice is not starting from k_index == 0 - // (only happen when it is the first slice of a threadblock) - // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: - // when the required shared memory size is larger than - // the remaining shared memory - if (slice_col == 0 || old_slice_row || - prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { - should_load_a = true; - } else { - should_load_a = false; - } - if (slice_iters) { - a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } + a_gl_rd_col = + a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + b_gl_rd = B_expert_off + b_gl_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading @@ -1996,8 +2320,26 @@ __global__ void Marlin( slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + 
s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } start_pipes(); } diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 601e2aa6f9913..27b6ffaa67176 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template -__global__ void permute_cols_kernel( - int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, - const int32_t* __restrict__ sorted_token_ids_ptr, - const int32_t* __restrict__ expert_ids_ptr, - const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, - int size_k, int top_k) {}; - -} // namespace marlin - -torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional c_or_none, - torch::Tensor& b_q_weight, - std::optional const& b_bias_or_none, torch::Tensor& b_scales, - std::optional const& b_zeros_or_none, - std::optional const& g_idx_or_none, - std::optional const& perm_or_none, torch::Tensor& workspace, - torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, - torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, - int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. 
template @@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, - int is_zp_float) { + int is_zp_float, bool is_a_8bit) { int pack_factor = 32 / num_bits; // Get B size @@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) - int sh_block_meta_size = tb_m * 4; - int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_block_meta_size = tb_m * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2); int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_bias_size = tb_n * 2; @@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int has_zp, int is_zp_float, - int max_shared_mem) { + int max_shared_mem, bool is_a_8bit) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, } // Check that pipeline fits into cache - int cache_size = get_kernel_cache_size( - th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size + 512 <= max_shared_mem; + int cache_size = + get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit); + return cache_size <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, 
THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ - } - - // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) - // this is the most common cases - // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) - // FZP: cases for float-zero-point (is_zp_float = true) - // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, 
N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - 
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) - - 
// We currently have 4-bit models only with group_blocks == 4 - #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) - -template -MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, - int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool m_block_size_8, - bool has_act_order, bool has_zp, - int group_blocks, int num_threads, - bool is_zp_float) { - int num_bits = q_type.size_bits(); +MarlinFuncPtr get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float) { + int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) - - NVFP4_GET_IF(vllm::kFE2M1f) - - BIGGROUP_GET_IF(vllm::kFE4M3fn) - - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) - if (std::is_same::value) { - if (false) { - } - MXFP4_GET_IF(vllm::kFE2M1f) - } +#include "kernel_selector.h" return kernel; } -template -exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, - int prob_n, int prob_k, int thread_m_blocks, - bool m_block_size_8, int num_bits, - int group_size, bool 
has_act_order, - bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem) { +exec_config_t determine_exec_config( + const vllm::ScalarType& a_type, const vllm::ScalarType& b_type, + const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, + int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, int group_size, bool has_act_order, + bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms, + bool is_a_8bit) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs @@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem)) { + is_k_full, has_zp, is_zp_float, max_shared_mem - 512, + is_a_8bit)) { continue; } int cache_size = get_kernel_cache_size( th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + is_a_8bit); int group_blocks = 0; if (!has_act_order) { group_blocks = group_size == -1 ? 
-1 : (group_size / 16); } - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, th_config.thread_n / 16, - th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + auto kernel = + get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, + th_config.thread_n / 16, th_config.thread_k / 16, + m_block_size_8, has_act_order, has_zp, group_blocks, + th_config.num_threads, is_zp_float); if (kernel == MarlinDefault) continue; - if (thread_m_blocks > 1) { - exec_cfg = {1, th_config}; - break; - } else { - cudaFuncAttributes attr; - cudaFuncGetAttributes(&attr, kernel); - int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; - int allow_count = min(device_max_reg_size / reg_size, - max_shared_mem / (cache_size + 1024)); + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, + max_shared_mem / (cache_size + 1536)); + if (thread_m_blocks == 1) allow_count = max(min(allow_count, 4), 1); - if (allow_count > count) { - count = allow_count; - exec_cfg = {count, th_config}; - }; + else + allow_count = max(min(allow_count, 2), 1); + + if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) { + allow_count = + max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1); } + + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; } return exec_cfg; } -template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, - void* s, void* s2, void* zp, void* g_idx, void* perm, - void* a_tmp, void* sorted_token_ids, void* expert_ids, - void* num_tokens_past_padded, void* topk_weights, - int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_bias, - bool has_act_order, bool is_k_full, 
bool has_zp, int num_groups, - int group_size, int dev, cudaStream_t stream, int thread_k, - int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { + void* a_s, void* b_s, void* g_s, void* zp, void* g_idx, + void* perm, void* a_tmp, void* sorted_token_ids, + void* expert_ids, void* num_tokens_past_padded, + void* topk_weights, int moe_block_size, int num_experts, + int top_k, bool mul_topk_weights, bool is_ep, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& a_type, vllm::ScalarType const& b_type, + vllm::ScalarType const& c_type, vllm::ScalarType const& s_type, + bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int blocks_per_sm, + bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; - - if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. 
Got = ", - q_type.str()); - } + bool is_a_8bit = a_type.size_bits() == 8; TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } } - int num_bits = q_type.size_bits(); + int num_bits = b_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* bias_ptr = (const int4*)b_bias; - const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* a_s_ptr = (const float*)a_s; + const int4* b_s_ptr = (const int4*)b_s; + const uint16_t* g_s_ptr = (const uint16_t*)g_s; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); + int major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + dev); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + dev); + TORCH_CHECK(major_capability * 10 + minor_capability >= 80, + "marlin kernel only support Ampere or newer GPUs."); + if (a_type == vllm::kFE4M3fn) { + TORCH_CHECK(major_capability * 10 + minor_capability >= 89, + "FP8 only support Ada Lovelace or newer GPUs."); + TORCH_CHECK( + major_capability * 10 + minor_capability == 89 || + major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + // Set thread config exec_config_t exec_cfg; thread_config_t thread_tfg; if (thread_k != -1 && thread_n != -1) { - thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; - exec_cfg = exec_config_t{1, 
thread_tfg}; + thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64}; + if (blocks_per_sm == -1) blocks_per_sm = 1; + exec_cfg = exec_config_t{blocks_per_sm, thread_tfg}; TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); } else { // Auto config - exec_cfg = determine_exec_config( - q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem); + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts, + top_k, thread_m_blocks, m_block_size_8, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms, + is_a_8bit); thread_tfg = exec_cfg.tb_cfg; } @@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; - TORCH_CHECK( - is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem), - "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - ", thread_k = ", thread_tfg.thread_k, - ", thread_n = ", thread_tfg.thread_n, - ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", - prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, ", has_act_order = ", has_act_order, - ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, - ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); + TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem, is_a_8bit), + "Invalid thread config: thread_m_blocks = ", 
thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem = ", max_shared_mem); - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, - has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + int sh_cache_size = + get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit); + + auto kernel = get_marlin_kernel( + a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem); + prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce); // clang-format on } } // namespace MARLIN_NAMESPACE_NAME torch::Tensor moe_wna16_marlin_gemm( - torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& 
b_bias_or_none, torch::Tensor& b_scales, + std::optional const& a_scales_or_none, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, @@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); + bool is_zp_float, int64_t thread_k, int64_t thread_n, + int64_t blocks_per_sm) { + vllm::ScalarTypeId a_type_id, c_type_id, s_type_id; + + auto c_dtype = a.dtype(); + if (a.scalar_type() == at::ScalarType::Half) { + a_type_id = vllm::kFloat16.id(); + c_type_id = vllm::kFloat16.id(); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + a_type_id = vllm::kBFloat16.id(); + c_type_id = vllm::kBFloat16.id(); + } else { + c_dtype = b_scales.dtype(); + if (b_scales.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + c_type_id = vllm::kBFloat16.id(); + + TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4"); + torch::Tensor c = c_or_none.value(); + c_dtype = c.dtype(); + + if (c.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (c.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + TORCH_CHECK(false, "unsupported c dtype"); + } + } + + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + a_type_id = vllm::kFE4M3fn.id(); 
+ } else if (a.scalar_type() == at::ScalarType::Char) { + a_type_id = vllm::kS8.id(); + } else { + TORCH_CHECK(false, "unsupported `a` scalar_type"); + } + } + + s_type_id = c_type_id; + if (b_type_id == vllm::kFE2M1f.id()) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + s_type_id = vllm::kFE4M3fn.id(); + } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); + } else { + TORCH_CHECK(false, + "When b_type = float4_e2m1f, b_scale scalar type must be", + "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."); + } + } + + vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); + vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id); + vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id); + vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id); + + int pack_factor = 32 / b_type.size_bits(); + int num_experts = b_q_weight.size(0); if (moe_block_size != 8) { TORCH_CHECK(moe_block_size % 16 == 0, @@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; + torch::Tensor a_scales; + auto options = torch::TensorOptions().dtype(c_dtype).device(a.device()); + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + + if (a_scales_or_none.has_value()) { + a_scales = a_scales_or_none.value(); + TORCH_CHECK(a_type.size_bits() == 8, + "a_scales can only be used for 8bit activation."); + } else { + a_scales = torch::empty({0}, options_fp32); + TORCH_CHECK(a_type.size_bits() != 8, + "the a_scales parameter must be passed for 8bit activation."); + } + // sms: number of SMs to use for the 
kernel int sms = -1; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; if (c_or_none.has_value()) { c = c_or_none.value(); @@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm( // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - auto options_fp32 = - torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce && !use_atomic_add) { // max num of threadblocks is sms * 4 long max_c_tmp_size = min( @@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm( bool has_zp = b_zeros.size(-1) > 0; if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + b_type == vllm::kU4 || b_type == vllm::kU8, + "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str()); } else { - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " - "float4_e2m1f when " - "has_zp = False. 
Got = ", - b_q_type.str()); + TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f, + "b_type must be uint4b8, uint8b128, int4, int8, " + "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ", + b_type.str()); } if (has_zp && is_zp_float) { @@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - MARLIN_NAMESPACE_NAME::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), - sorted_token_ids.data_ptr(), expert_ids.data_ptr(), - num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), - moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - - 
MARLIN_NAMESPACE_NAME::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - sorted_token_ids.data_ptr(), expert_ids.data_ptr(), - num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), - moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, - workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, - "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, + "scalar type of a_scales must be float"); + TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), + "scalar type of global_scale must be the same with c"); + if (a_type.size_bits() == 16) { + TORCH_CHECK( + a.scalar_type() == c.scalar_type(), + "scalar type of a must be the same with c for 16 bit activation"); } + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, num_experts, top_k, + mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(), + a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full, + has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce, + is_zp_float); + return c; } -#endif - TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_wna16_marlin_gemm", 
&moe_wna16_marlin_gemm); } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index bd95ade40a083..e0a8280722f3c 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," "Tensor! b_q_weight, Tensor? b_bias_or_none," - "Tensor! b_scales, Tensor? global_scale, Tensor? " + "Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? " "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! topk_weights, int moe_block_size, int top_k, " - "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "bool mul_topk_weights, bool is_ep, int b_type_id," "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," - "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + "bool use_fp32_reduce, bool is_zp_float," + "int thread_k, int thread_n, int blocks_per_sm) -> Tensor"); + m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! 
" diff --git a/csrc/ops.h b/csrc/ops.h index f8bdc61aaa8ec..4bb7857b15032 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -52,14 +52,13 @@ void paged_attention_v2( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); -#ifndef USE_ROCM void merge_attn_states(torch::Tensor& output, std::optional output_lse, const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); - +#ifndef USE_ROCM void convert_vertical_slash_indexes( torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 5b007e5ea3283..6744402783832 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -22,6 +22,7 @@ #include #include #include +#include "cutlass_extensions/common.hpp" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" @@ -173,7 +174,7 @@ void run_get_group_gemm_starts( } template -void run_fp4_blockwise_scaled_group_mm( +void run_fp4_blockwise_scaled_group_mm_sm100( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, @@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm( auto can_implement_status = gemm_op.can_implement(args); TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM"); + "Failed to implement GEMM: status=", (int)can_implement_status); // Run the GEMM auto status = gemm_op.initialize(args, workspace.data_ptr()); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: 
status=", (int)status, + " workspace_size=", workspace_size, " num_experts=", num_experts, + " M=", M, " N=", N, " K=", K); status = gemm_op.run(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } +void run_fp4_blockwise_scaled_group_mm_sm120( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + // NOTE: For SM120 it seems templating the output type is not supported and + // we need to hardcode the output type to bfloat16 + using ElementC = cutlass::bfloat16_t; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + + using FusionOperation = cutlass::epilogue::fusion::LinearCombination< + ElementD, ElementAccumulator, ElementC, ElementAccumulator>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, 
OperatorClass, MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutD*, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, + LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = 
torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + 
static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + fusion_args.beta = 0.0f; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM: status=", (int)can_implement_status); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, " num_experts=", num_experts, + " M=", M, " N=", N, " K=", K); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + if (version_num >= 120 && version_num < 130) { + run_fp4_blockwise_scaled_group_mm_sm120( + 
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + return; + } +#endif #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 + if (version_num >= 100 && version_num < 120) { + run_fp4_blockwise_scaled_group_mm_sm100( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + return; + } +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ", + version_num, ". Required capability: 100 or 120"); +} + +#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ + (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; #endif @@ -374,7 +583,8 @@ void cutlass_fp4_group_mm( const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 +#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ + (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) // Input validation CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); @@ -408,6 +618,14 @@ void cutlass_fp4_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); } else { + #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 + int32_t version_num = get_sm_version_num(); + if (version_num >= 120 && version_num < 130) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ", + output.scalar_type()); + } + #endif run_fp4_blockwise_scaled_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); @@ -416,8 +634,8 @@ void cutlass_fp4_group_mm( TORCH_CHECK_NOT_IMPLEMENTED( false, "No 
compiled cutlass_fp4_group_mm kernel, vLLM must " - "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA " - "12.8 or above."); + "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 " + "and CUDA 12.8 or above."); #endif } diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 6d385e0dd94e7..82c53c2375a31 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float; constexpr auto INT = at::ScalarType::Int; constexpr auto UINT8 = at::ScalarType::Byte; -void scaled_fp4_experts_quant_sm100a( +void scaled_fp4_experts_quant_sm1xxa( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index c2b39e5438805..fb6d22f035b99 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input_sf); #endif -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 -void scaled_fp4_experts_quant_sm100a( +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void scaled_fp4_experts_quant_sm1xxa( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, @@ -54,8 +55,9 @@ void scaled_fp4_experts_quant( torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 - return scaled_fp4_experts_quant_sm100a( +#if 
(defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return scaled_fp4_experts_quant_sm1xxa( output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); #endif diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index 9cba2828aac2e..d9c4d24d8e1f2 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -15,6 +15,8 @@ */ #include +#include +#include "cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, @@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& alpha); #endif -void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 - return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); -#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 - return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); +void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& A_sf, + const torch::Tensor& B_sf, + const torch::Tensor& alpha) { + // Make sure we're on A's device. 
+ const c10::cuda::OptionalCUDAGuard device_guard(device_of(A)); + const int32_t sm = get_sm_version_num(); + +#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + if (sm >= 100 && sm < 120) { + cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); + return; + } #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "No compiled nvfp4 mm kernel, vLLM should " - "be compiled using CUDA 12.8 and target " - "compute capability 100 or above."); + +#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 + if (sm >= 120 && sm < 130) { + cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); + return; + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm, + ". Recompile with CUDA >= 12.8 and CC >= 100."); } bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { int runtimeVersion; cudaRuntimeGetVersion(&runtimeVersion); return cuda_device_capability >= 100 && runtimeVersion >= 12080; -} \ No newline at end of file +} diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index 03bd5964a7fc4..e306ff02605b9 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { #pragma unroll for (int k_idx = 0; k_idx < 2; ++k_idx) { - FType low16 = - ScalarType::float2num(C_frag[m_idx][n_idx][k_idx * 2]); - FType high16 = - ScalarType::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]); + FType low16 = MarlinScalarType2::float2num( + C_frag[m_idx][n_idx][k_idx * 2]); + FType high16 = MarlinScalarType2::float2num( + C_frag[m_idx][n_idx][k_idx * 2 + 1]); uint32_t tmp = (reinterpret_cast(low16) & 0xffff) | (reinterpret_cast(high16) << 16); int sts_offset = diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh 
b/csrc/quantization/gptq_allspark/allspark_utils.cuh index 831413016538e..14a61ad8fd880 100644 --- a/csrc/quantization/gptq_allspark/allspark_utils.cuh +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -8,7 +8,7 @@ #include #include #include "../gptq_marlin/marlin_dtypes.cuh" -using marlin::ScalarType; +using marlin::MarlinScalarType2; namespace allspark { @@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; for (int i = 0; i < n_mat; ++i) { - sum += ScalarType::num2float(C_split[idx + i * matrix_size]); + sum += MarlinScalarType2::num2float(C_split[idx + i * matrix_size]); } - C[idx] = ScalarType::float2num(sum); + C[idx] = MarlinScalarType2::float2num(sum); } template diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/gptq_marlin/.gitignore index 77088552b85b4..ba805f9250ece 100644 --- a/csrc/quantization/gptq_marlin/.gitignore +++ b/csrc/quantization/gptq_marlin/.gitignore @@ -1 +1,2 @@ -kernel_*.cu \ No newline at end of file +sm*_kernel_*.cu +kernel_selector.h diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu index e607107b3e77c..307bae6738ecf 100644 --- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -4,14 +4,16 @@ namespace marlin { -template +template __global__ void awq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 
2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; @@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel( extern __shared__ int4 sh[]; - constexpr int tile_n_ints = tile_n_size / pack_factor; + constexpr int tile_n_ints = target_tile_n_size / pack_factor; constexpr int stage_n_threads = tile_n_ints / 4; - constexpr int stage_k_threads = tile_k_size; + constexpr int stage_k_threads = target_tile_k_size; constexpr int stage_size = stage_k_threads * stage_n_threads; auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { @@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel( return; } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * target_tile_n_size; int first_n_packed = first_n / pack_factor; int4* sh_ptr = sh + stage_size * pipe; @@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel( auto k_id = threadIdx.x / stage_n_threads; auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * target_tile_k_size; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast( @@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel( } int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = (warp_id / (is_a_8bit ? 
2 : 1)) * 16 + tc_col; int cur_n_packed = cur_n / pack_factor; int cur_n_pos = cur_n % pack_factor; @@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel( uint32_t vals[8]; #pragma unroll for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; + if constexpr (is_a_8bit) { + int cur_elem = tc_row + i; - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + - sh_stride * cur_elem]; + int packed_src_0 = + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + + sh_stride * cur_elem]; + int packed_src_1 = + sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + + sh_stride * (cur_elem + 16)]; - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } else { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = + sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + constexpr int tile_size = + target_tile_k_size * target_tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 
0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; uint32_t res = 0; #pragma unroll @@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel( uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; @@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel( } // namespace marlin -#define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - marlin::awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::awq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, out_ptr, size_k, size_n); \ +#define CALL_IF(NUM_BITS, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ } torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits) { + int64_t size_n, int64_t num_bits, + bool is_a_8bit) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", marlin::tile_k_size); @@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, if (false) { } - CALL_IF(4) - CALL_IF(8) + CALL_IF(4, false) + CALL_IF(8, false) + CALL_IF(4, true) + CALL_IF(8, true) else { - TORCH_CHECK(false, "Unsupported repack config: num_bits = 
", num_bits); + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", is_a_8bit = ", is_a_8bit); } return out; diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index e8b0c302b2021..26b8d40368aa9 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -470,6 +470,50 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } +template <> +__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>( + int q, __nv_fp8x4_e4m3* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4; + constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70707070; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + + // Note1: reverse indexing is intentional because weights are permuted + // Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of + // `frag_b[1]` to fit the layout of tensorcore + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, int32_t* frag_b) { + constexpr int repeated_zp = 0x08080808; + constexpr int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>( + int q, __nv_fp8x4_e4m3* frag_b) { + int s = q & 0x08080808; + int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + q >>= 4; + s = q & 0x08080808; + int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + + frag_b[0] = *reinterpret_cast(&Out1); + frag_b[1] = *reinterpret_cast(&Out2); +} + template __device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); @@ 
-515,6 +559,49 @@ __device__ inline void dequant_fp8_scales( // Note: reverse indexing is intentional because weights are permuted frag_b[1] = *reinterpret_cast(&Out1); frag_b[0] = *reinterpret_cast(&Out2); +}; + +// subtract zero point in quanted format and then dequant +template +__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp); + +template <> +__device__ inline void sub_zp_and_dequant( + int q, int32_t* frag_b, int zp) { + // INT4 with zp -> INT8 + // see https://github.com/vllm-project/vllm/pull/24722 + int repeated_zp = 0x01010101 * zp; + int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(), + true>(int q, __nv_fp8x4_e4m3* frag_b, + int zp) { + // INT4 with zp -> FP8 + // see https://github.com/vllm-project/vllm/pull/24722 + uint32_t u_q = *reinterpret_cast(&q); + uint32_t u_zp = *reinterpret_cast(&zp); + uint32_t u_zp1 = u_zp + 1; + uint32_t repeated_zp = 0x01010101 * u_zp; + + uint32_t q0, s; + q0 = (u_q & 0x0F0F0F0F) | 0x70707070; + s = (q0 + repeated_zp) & 0x80808080; + uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s; + + u_q >>= 4; + q0 = (u_q & 0x0F0F0F0F) | 0x70707070; + s = (q0 + repeated_zp) & 0x80808080; + uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s; + + frag_b[0] = *reinterpret_cast(&Out1); + frag_b[1] = *reinterpret_cast(&Out2); } #endif diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 42d3b456096ee..27ef7271ba41c 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -4,141 +4,292 @@ import glob import itertools import os import subprocess +import sys import jinja2 -FILE_HEAD = """ -// auto generated by generate.py -// clang-format off +ARCHS = [] +SUPPORT_FP8 = 
False +for arch in sys.argv[1].split(","): + arch = arch[: arch.index(".") + 2].replace(".", "") + arch = int(arch) + # only SM89 and SM120 fully support + # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM90 and SM100 can use this PTX, but it's simulated + # with FP16 MMA, so it cannot achieve any acceleration. + if arch in [89, 120]: + SUPPORT_FP8 = True +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py +// clang-format off +""".lstrip() + +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ #include "kernel.h" #include "marlin_template.h" namespace MARLIN_NAMESPACE_NAME { -""".strip() +""" +) TEMPLATE = ( "template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " "{{s_type_id}}, " "{{threads}}, " "{{thread_m_blocks}}, " "{{thread_n_blocks}}, " "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " + "{{m_block_size_8}}, " "{{stages}}, " "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" + "{{is_zp_float}}>" "( MARLIN_KERNEL_PARAMS );" ) -# int8 with zero point case (vllm::kU8) is also supported, -# we don't add it to reduce wheel size. 
-SCALAR_TYPES = [ - "vllm::kU4", - "vllm::kU4B8", - "vllm::kU8B128", - "vllm::kFE4M3fn", - "vllm::kFE2M1f", -] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] -# group_blocks: -# = 0 : act order case -# = -1 : channelwise quantization -# > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] -DTYPES = ["fp16", "bf16"] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # HQQ + { + "a_type": ["kFloat16"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [4], + "is_zp_float": True, + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # GPTQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": 
["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): subprocess.call(["rm", "-f", filename]) + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + def generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + result_dict = {} + + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + is_zp_float = quant_config.get("is_zp_float", False) + all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": + continue + if "16" in a_type and "16" in c_type and a_type != c_type: + continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only 
need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue + + config = { + "threads": threads, + "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": "pipe_stages", + "group_blocks": group_blocks, + "is_zp_float": "true" if is_zp_float else "false", + } + + result_dict[(a_type, b_type, c_type)].append(config) + + kernel_selector_str = FILE_HEAD_COMMENT + + for (a_type, b_type, c_type), config_list in result_dict.items(): all_template_str_list = [] + for config in config_list: + s_type = config["s_type"] + template_str = jinja2.Template(TEMPLATE).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + all_template_str_list.append(template_str) - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS - ): - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", - "vllm::kU8B128", - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == 
{config['m_block_size_8']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: - continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: - continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: - continue + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" - - is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: - # HQQ (is_zp_float = true) only supports - # 4bit quantization and fp16 - is_zp_float_list.append(True) - - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == "bf16": - s_type = "vllm::kBFloat16" - - for is_zp_float in is_zp_float_list: - template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, 
- thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=is_zp_float, + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, ) - - all_template_str_list.append(template_str) + + "\n" + ) file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " TORCH_CHECK(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + if __name__ == "__main__": remove_old_kernels() diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index cc30abcf00800..28ff06559a98a 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm( std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool 
is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return cache_size + 512 <= max_shared_mem; + return cache_size <= max_shared_mem; } - #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ - } - - // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) - // this is the most common cases - // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) - // FZP: cases for float-zero-point (is_zp_float = true) - // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, 
K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, 
NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - 
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 4, 8, 128) - - // We currently have 4-bit models only with group_blocks == 4 - #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - -template -MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, - int thread_m_blocks, int thread_n_blocks, - int thread_k_blocks, bool m_block_size_8, - bool has_act_order, bool has_zp, - int group_blocks, int num_threads, - bool is_zp_float) { - int num_bits = q_type.size_bits(); +MarlinFuncPtr 
get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float) { + int num_bits = b_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) - - NVFP4_GET_IF(vllm::kFE2M1f) - - BIGGROUP_GET_IF(vllm::kFE4M3fn) - - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) - - if (std::is_same::value) { - if (false) { - } - FZP_GET_IF(vllm::kU4) - } - if (std::is_same::value) { - if (false) { - } - MXFP4_GET_IF(vllm::kFE2M1f) - } + #include "kernel_selector.h" return kernel; } -template -exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, - int prob_n, int prob_k, int thread_m_blocks, - bool m_block_size_8, int num_bits, - int group_size, bool has_act_order, - bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem, - int sms) { +exec_config_t determine_exec_config( + const vllm::ScalarType& a_type, const vllm::ScalarType& b_type, + const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8, + int num_bits, int group_size, bool has_act_order, bool is_k_full, + bool has_zp, bool is_zp_float, int max_shared_mem, int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t* thread_configs = thread_m_blocks > 1 ? 
large_batch_thread_configs @@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem)) { + is_zp_float, max_shared_mem - 512)) { continue; } @@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, group_blocks = group_size == -1 ? -1 : group_size / 16; } - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, th_config.thread_n / 16, - th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + auto kernel = + get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks, + th_config.thread_n / 16, th_config.thread_k / 16, + m_block_size_8, has_act_order, has_zp, group_blocks, + th_config.num_threads, is_zp_float); if (kernel == MarlinDefault) continue; @@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, return exec_cfg; } -template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, - void* s, void* s2, void* zp, void* g_idx, void* perm, - void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, - void* workspace, vllm::ScalarType const& q_type, bool has_bias, + void* a_s, void* b_s, void* g_s, void* zp, void* g_idx, + void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, + int lda, void* workspace, vllm::ScalarType const& a_type, + vllm::ScalarType const& b_type, vllm::ScalarType const& c_type, + vllm::ScalarType const& s_type, bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k_init, int thread_n_init, int sms, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { - if (has_zp) { - TORCH_CHECK( - q_type == vllm::kU4 || q_type == vllm::kU8, - "q_type must be u4 or u8 when 
has_zp = True. Got = ", q_type.str()); - } else { - TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. Got = ", - q_type.str()); - } - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } } - int num_bits = q_type.size_bits(); + int num_bits = b_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; + const int4* bias_ptr = (const int4*)b_bias; - const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* a_s_ptr = (const float*)a_s; + const int4* b_s_ptr = (const int4*)b_s; + const uint16_t* g_s_ptr = (const uint16_t*)g_s; + const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; int4* a_tmp_ptr = (int4*)a_tmp; - int* locks = (int*)workspace; if (has_act_order) { @@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); + int major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + dev); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + dev); + TORCH_CHECK(major_capability * 10 + minor_capability >= 80, + "marlin kernel only support Ampere or newer GPUs."); + if (a_type == vllm::kFE4M3fn) { + TORCH_CHECK( + major_capability * 10 + minor_capability == 89 || + major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + int max_par = 16; 
if (prob_n <= 4096) max_par = 16 * 8; int max_shared_mem_new = max_shared_mem; @@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int thread_n = thread_n_init; int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); - int m_block_size_8 = prob_m_split <= 8; + int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16; // Set thread config exec_config_t exec_cfg; @@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, " is not divisible by thread_k = ", thread_k); } else { // Auto config - exec_cfg = determine_exec_config( - q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, - num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem, sms); + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k, + thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem, sms); thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_n != -1) { + if (prob_n / thread_tfg.thread_n * + div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <= + sms) { + if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split, + prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem_new)) { + thread_tfg = {128, 64, 128}; + exec_cfg = {1, thread_tfg}; + } + } + } + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { max_thread_m_blocks--; continue; @@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", max_shared_mem_new = ", max_shared_mem_new); - auto kernel = get_marlin_kernel( - q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, - m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, - is_zp_float); + auto kernel = get_marlin_kernel( + a_type, 
b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, num_groups, prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on - A_ptr += prob_m_split * (lda / 8); + bool is_a_8bit = a_type.size_bits() == 8; + A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8)); + a_s_ptr += prob_m_split; C_ptr += prob_m_split * (prob_n / 8); rest_m -= prob_m_split; } @@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& a_scales_or_none, std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, - vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); + vllm::ScalarTypeId a_type_id, c_type_id, s_type_id; + + auto c_dtype = a.dtype(); + if (a.scalar_type() == at::ScalarType::Half) { + a_type_id = vllm::kFloat16.id(); + c_type_id = vllm::kFloat16.id(); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + 
a_type_id = vllm::kBFloat16.id(); + c_type_id = vllm::kBFloat16.id(); + } else { + c_dtype = b_scales.dtype(); + if (b_scales.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + c_type_id = vllm::kBFloat16.id(); + + TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4"); + torch::Tensor c = c_or_none.value(); + c_dtype = c.dtype(); + + if (c.scalar_type() == at::ScalarType::Half) { + c_type_id = vllm::kFloat16.id(); + } else if (c.scalar_type() == at::ScalarType::BFloat16) { + c_type_id = vllm::kBFloat16.id(); + } else { + TORCH_CHECK(false, "unsupported c dtype"); + } + } + + if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) { + a_type_id = vllm::kFE4M3fn.id(); + } else if (a.scalar_type() == at::ScalarType::Char) { + a_type_id = vllm::kS8.id(); + } else { + TORCH_CHECK(false, "unsupported `a` scalar_type"); + } + } + + s_type_id = c_type_id; + if (b_type_id == vllm::kFE2M1f.id()) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + s_type_id = vllm::kFE4M3fn.id(); + } else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + s_type_id = vllm::kFE8M0fnu.id(); + } else { + TORCH_CHECK(false, + "When b_type = float4_e2m1f, b_scale scalar type must be", + "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."); + } + } + + vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id); + vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id); + vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id); + vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id); + + int pack_factor = 32 / b_type.size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm( TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is 
not contiguous"); + torch::Tensor a_scales; + auto options = torch::TensorOptions().dtype(c_dtype).device(a.device()); + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + + if (a_scales_or_none.has_value()) { + a_scales = a_scales_or_none.value(); + TORCH_CHECK(a_type.size_bits() == 8, + "a_scales can only be used for 8bit activation."); + } else { + a_scales = torch::empty({0}, options_fp32); + TORCH_CHECK(a_type.size_bits() != 8, + "the a_scales parameter must be passed for 8bit activation."); + } + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) int thread_k = -1; @@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm( // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; if (c_or_none.has_value()) { c = c_or_none.value(); @@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm( // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - auto options_fp32 = - torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce) { int max_m_block_size = (size_m + 16 - 1) / 16 * 16; max_m_block_size = min(max_m_block_size, 64); @@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor global_scale; if (global_scale_or_none.has_value()) { global_scale = global_scale_or_none.value(); - TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, + TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), + TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm( bool has_zp = b_zeros.size(-1) > 0; if (has_zp) { TORCH_CHECK( - 
b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + b_type == vllm::kU4 || b_type == vllm::kU8, + "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str()); } else { - TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " - "float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); + TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f, + "b_type must be uint4b8, uint8b128, int4, int8, " + "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ", + b_type.str()); } if (has_zp && is_zp_float) { @@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - marlin::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, - is_k_full, has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - void* scales_ptr; - if (b_q_type == vllm::kFE2M1f) { - if (group_size == 
16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, - "float4_e2m1f only supports group_size == 16 (NVFP4) ", - "and group_size == 32 (MXFP4)"); - } else { - scales_ptr = b_scales.data_ptr(); - } - - marlin::marlin_mm( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, - has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, + "scalar type of a_scales must be float"); + TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), + "scalar type of global_scale must be the same with c"); + if (a_type.size_bits() == 16) { + TORCH_CHECK( + a.scalar_type() == c.scalar_type(), + "scalar type of a must be the same with c for 16 bit activation"); } + marlin::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), + b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), + workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + return c; } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index ad80d51ece94e..796e6c5359da1 100644 --- 
a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -4,15 +4,18 @@ namespace marlin { -template +template __global__ void gptq_marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; - int k_tiles = size_k / tile_k_size; - int n_tiles = size_n / tile_n_size; + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; @@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel( extern __shared__ int4 sh[]; - constexpr int perm_size = tile_k_size / 4; + constexpr int perm_size = target_tile_k_size / 4; int4* sh_perm_ptr = sh; int4* sh_pipe_ptr = sh_perm_ptr; @@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel( sh_pipe_ptr += perm_size; } - constexpr int tile_ints = tile_k_size / pack_factor; + constexpr int tile_ints = target_tile_k_size / pack_factor; - constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_n_threads = target_tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? 
target_tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; + int first_k_int4 = (k_tile_id * target_tile_k_size) / 4; int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); @@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel( return; } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * target_tile_n_size; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; @@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel( auto k_id = threadIdx.x / stage_n_threads; auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * target_tile_k_size; int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], @@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel( } int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = (warp_id / (is_a_8bit ? 
2 : 1)) * 16 + tc_col; - constexpr int sh_stride = 64; + constexpr int sh_stride = target_tile_n_size; constexpr uint32_t mask = (1 << num_bits) - 1; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; @@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel( uint32_t vals[8]; if constexpr (has_perm) { + static_assert(!is_a_8bit); for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; @@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel( #pragma unroll for (int i = 0; i < tile_ints; i++) { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + if constexpr (is_a_8bit) { + b1_vals[i] = + sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8]; + } else { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } } #pragma unroll for (int i = 0; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i]; + int cur_elem = tc_row + (is_a_8bit ? 
i : tc_offsets[i]); int cur_int = cur_elem / pack_factor; int cur_pos = cur_elem % pack_factor; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + if constexpr (is_a_8bit) + vals[4 + i] = + (b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask; + else + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + constexpr int tile_size = + target_tile_k_size * target_tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; uint32_t res = 0; #pragma unroll @@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel( uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); + const int ii = is_a_8bit ? 
i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; @@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel( } // namespace marlin -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ +#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \ + is_a_8bit == IS_A_8BIT) { \ cudaFuncSetAttribute( \ marlin::gptq_marlin_repack_kernel, \ + HAS_PERM, IS_A_8BIT>, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ marlin::gptq_marlin_repack_kernel \ + HAS_PERM, IS_A_8BIT> \ <<>>( \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, - int64_t num_bits) { + int64_t num_bits, bool is_a_8bit) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", marlin::tile_k_size); @@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, if (false) { } - CALL_IF(4, false) - CALL_IF(4, true) - CALL_IF(8, false) - CALL_IF(8, true) + CALL_IF(4, false, false) + CALL_IF(4, true, false) + CALL_IF(8, false, false) + CALL_IF(8, true, false) + + CALL_IF(4, false, true) + CALL_IF(8, false, true) + else { TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, - ", has_perm = ", has_perm); + ", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit); } return out; diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index bb454f6aff22a..b3b79c8aec452 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -11,17 +11,19 @@ const int4 *__restrict__ A, const int4 *__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const 
int4 *__restrict__ b_bias_ptr, \ + const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ int max_shared_mem namespace MARLIN_NAMESPACE_NAME { -template (__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index cc16054814342..a4807a6887f81 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -2,8 +2,10 @@ #ifndef _data_types_cuh #define _data_types_cuh #include "marlin.cuh" +#include "core/scalar_type.hpp" 
#include #include +#include #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin @@ -11,14 +13,16 @@ namespace MARLIN_NAMESPACE_NAME { -template -class ScalarType {}; +template +class MarlinScalarType {}; template <> -class ScalarType { +class MarlinScalarType { public: using scalar_t = half; using scalar_t2 = half2; + using scalar_t4 = half2; + using scalar_32bit_t = half2; // Matrix fragments for tensor core instructions; their precise layout is // documented here: @@ -27,6 +31,7 @@ class ScalarType { using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; using FragZP = Vec; static __device__ float inline num2float(const half x) { @@ -44,18 +49,25 @@ class ScalarType { static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } + + static __host__ __device__ float2 inline num22float2(const half2 x) { + return __half22float2(x); + } }; template <> -class ScalarType { +class MarlinScalarType { public: using scalar_t = nv_bfloat16; using scalar_t2 = nv_bfloat162; + using scalar_t4 = nv_bfloat162; + using scalar_32bit_t = nv_bfloat162; using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; using FragZP = Vec; #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 @@ -75,9 +87,63 @@ class ScalarType { static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return __bfloat1622float2(x); + } #endif }; +template <> +class MarlinScalarType { + public: + using scalar_t = __nv_fp8_e4m3; + using scalar_t2 = __nv_fp8x2_e4m3; + using scalar_t4 = __nv_fp8x4_e4m3; + using scalar_32bit_t = __nv_fp8x4_e4m3; + + using FragA = Vec<__nv_fp8x4_e4m3, 4>; + using FragB = Vec<__nv_fp8x4_e4m3, 2>; + using FragC = Vec; + using FragZP = Vec<__nv_fp8x2_e4m3, 4>; + + static __host__ __device__ + float2 inline 
num22float2(const __nv_fp8x2_e4m3 x) { + return (float2)x; + } +}; + +template <> +class MarlinScalarType { + public: + using scalar_t = int8_t; + using scalar_t2 = int16_t; + using scalar_t4 = int32_t; + using scalar_32bit_t = int32_t; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragZP = Vec; +}; + +template +class MarlinScalarType2 {}; + +template <> +class MarlinScalarType2 : public MarlinScalarType {}; + +template <> +class MarlinScalarType2 + : public MarlinScalarType {}; + +template <> +class MarlinScalarType2<__nv_fp8_e4m3> + : public MarlinScalarType {}; + +template <> +class MarlinScalarType2 : public MarlinScalarType {}; + } // namespace MARLIN_NAMESPACE_NAME #endif diff --git a/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu b/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu new file mode 100644 index 0000000000000..7d4c97fb57ed4 --- /dev/null +++ b/csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu @@ -0,0 +1,106 @@ + + +#include "marlin.cuh" + +#include "core/registration.h" + +// for only non-zp format (like gptq) +__global__ void marlin_int4_fp8_preprocess_kernel_without_zp( + // qweight: (size_k * size_n // 8,) + const int32_t* __restrict__ qweight, + // output: same shape with qweight + int32_t* __restrict__ output) { + int32_t val = qweight[blockIdx.x * 32 + threadIdx.x]; + int32_t new_val = 0; + +#pragma unroll + for (int32_t i = 0; i < 8; i++) { + int32_t single_val = val & 0xF; + single_val = single_val >= 8 ? 
single_val - 8 : 15 - single_val; + new_val |= single_val << (i * 4); + val >>= 4; + } + + output[blockIdx.x * 32 + threadIdx.x] = new_val; +} + +// for awq format only (with zp and with awq weight layout) +__global__ void marlin_int4_fp8_preprocess_kernel_awq( + // AWQ qweight: (size_k, size_n // 8) + const int32_t* __restrict__ qweight, + // output: same shape with qweight + int32_t* __restrict__ output, + // AWQ zeros: (size_k // group_size, size_n // 8) + const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k, + int32_t group_size) { + int32_t val = + qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y]; + int32_t zero = + qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 + + blockIdx.y]; + int32_t new_val = 0; + +#pragma unroll + for (int32_t i = 0; i < 8; i++) { + int32_t single_val = val & 0xF; + int32_t single_zero = zero & 0xF; + + single_val = + single_val >= single_zero ? single_val - single_zero : 15 - single_val; + new_val |= single_val << (i * 4); + val >>= 4; + zero >>= 4; + } + + output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val; +} + +torch::Tensor marlin_int4_fp8_preprocess( + torch::Tensor& qweight, std::optional qzeros_or_none, + bool inplace) { + TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU"); + TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int, + "qweight.dtype != torch.int32"); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight)); + + torch::Tensor output = inplace ? 
qweight : torch::empty_like(qweight); + + if (!qzeros_or_none.has_value()) { + TORCH_CHECK(qweight.numel() * 8 % 256 == 0, + "qweight.numel() * 8 % 256 != 0"); + + int blocks = qweight.numel() * 8 / 256; + marlin_int4_fp8_preprocess_kernel_without_zp<<>>( + (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr()); + } else { + int32_t size_k = qweight.size(0); + int32_t size_n = qweight.size(1) * 8; + torch::Tensor qzeros = qzeros_or_none.value(); + + TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0"); + TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU"); + TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int, + "qweight.dtype != torch.int32"); + TORCH_CHECK(device_of(qweight) == device_of(qzeros), + "qzeros is not on the same device with qweight"); + + int32_t group_size = qweight.size(0) / qzeros.size(0); + TORCH_CHECK(qweight.size(1) == qzeros.size(1), + "qweight.size(1) != qzeros.size(1)"); + TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0, + "qweight.size(0) % qzeros.size(0) != 0"); + TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0"); + + dim3 blocks(size_k / 32, size_n / 8); + marlin_int4_fp8_preprocess_kernel_awq<<>>( + (const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(), + (const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size); + } + + return output; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess); +} diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index bfb0a3668f527..22bb71e482ce8 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { 
+template +__device__ inline void mma( + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragC& frag_c, int idx = 0) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c 
= reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } -template +template __device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { + const typename MarlinScalarType::FragA& a_frag, + const typename MarlinScalarType::FragB& frag_b, + const typename MarlinScalarType::FragB& frag_b2, + typename MarlinScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = 
reinterpret_cast(&frag_b); const uint32_t* b2 = reinterpret_cast(&frag_b2); float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (k_size == 16) { + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), 
"f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + } } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -template -__device__ inline void ldsm(typename ScalarType::FragA& frag_a, +template +__device__ inline void ldsm(typename MarlinScalarType::FragA& frag_a, const void* smem_ptr) { uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); @@ -159,47 +233,54 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
-template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +template +__device__ inline void scale(typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s = MarlinScalarType::num2num2( + reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -template +template __device__ inline void scale_and_sub( - typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s2 = ScalarType::num2num2(s); - scalar_t2 zp2 = ScalarType::num2num2(zp); + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t s, + typename MarlinScalarType::scalar_t zp) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s2 = MarlinScalarType::num2num2(s); + scalar_t2 zp2 = MarlinScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); } -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); +template +__device__ inline void sub_zp( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::scalar_t2& frag_zp, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 zp = MarlinScalarType::num2num2( + reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], 
zp); } // Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; +template +__device__ inline void scale4( + typename MarlinScalarType::FragB& frag_b, + typename MarlinScalarType::FragS& frag_s_1, + typename MarlinScalarType::FragS& frag_s_2, + typename MarlinScalarType::FragS& frag_s_3, + typename MarlinScalarType::FragS& frag_s_4, int i) { + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; @@ -213,12 +294,13 @@ __device__ inline void scale4(typename ScalarType::FragB& frag_b, } // Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { +template +__device__ inline void scale_float( + float* c, typename MarlinScalarType::FragS& s) { + using scalar_t = typename MarlinScalarType::scalar_t; scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + c[0] = __fmul_rn(c[0], MarlinScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], MarlinScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. 
@@ -270,9 +352,10 @@ __device__ inline void wait_negative_and_add(int* lock) { __syncthreads(); } -template __global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ A0, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C0, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ b_bias_ptr, - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 - // only) - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + // float scales of input matrix, only used when is_a_8bit == true. + // shape (m,) + const float* __restrict__ a_scales_ptr, + // fp16 quantization scales. shape (k/groupsize, n) + const int4* __restrict__ scales_ptr, + // fp16 global scale (for nvfp4// only) + const uint16_t* __restrict__ global_scale_ptr, + // 4bit packed zero-points of shape + // (k/groupsize, n/pack_factor) + const int4* __restrict__ zp_ptr, + // int32 group indices of shape k + const int* __restrict__ g_idx, int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -321,17 +409,35 @@ __global__ void Marlin( // ensures good utilization of all SMs for many kinds of shape and GPU // configurations, while requiring as few slow global cross-threadblock // reductions as possible. 
- using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. + if constexpr (a_type_id == vllm::kFE4M3fn.id()) return; + #endif + + using Adtype = MarlinScalarType; + using Cdtype = MarlinScalarType; + const int4* A = A0; + int4* C = C0; + + using scalar_t = typename MarlinScalarType::scalar_t; + using scalar_t2 = typename MarlinScalarType::scalar_t2; + using scalar_32bit_t = typename MarlinScalarType::scalar_32bit_t; + + using c_scalar_t = typename MarlinScalarType::scalar_t; + using c_scalar_t2 = typename MarlinScalarType::scalar_t2; + + using FragA = typename MarlinScalarType::FragA; + using FragB = typename MarlinScalarType::FragB; + using FragC = typename MarlinScalarType::FragC; + using FragS = typename MarlinScalarType::FragS; + using FragZP = typename MarlinScalarType::FragZP; + + static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id); + static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id); + static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); } else if constexpr (std::is_same::value) { @@ -340,27 +446,35 @@ __global__ void Marlin( static_assert(s_type == vllm::kFloat16); } - constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 
|| - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_a_8bit = a_type.size_bits() == 8; + if constexpr (!is_a_8bit) { + static_assert(std::is_same::value); + } + constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8; + constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 || + b_type == vllm::kS4 || b_type == vllm::kS8 || + b_type == vllm::kU4B8 || b_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || - w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || + is_a_8bit || b_type == vllm::kFE4M3fn || + b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == vllm::kU8); + has_zp && !is_zp_float && !(b_type == vllm::kU8); - scalar_t2 global_scale; - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - // NVFP4 format requires global scale - uint16_t val = scale2_ptr[0]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + c_scalar_t2 global_scale; + + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + uint16_t val = global_scale_ptr[0]; + global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); } constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); - constexpr int pack_factor = 32 / w_type.size_bits(); + extern __shared__ int4 sh[]; + float* sh_a_s = reinterpret_cast(sh); + int4* sh_new = sh + (is_a_8bit ? 
(4 * thread_m_blocks) : 0); + constexpr int pack_factor = 32 / b_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); // For larger GEMMs we run multiple batchsize 64 versions in parallel for a @@ -373,7 +487,19 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x; + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { @@ -385,28 +511,21 @@ __global__ void Marlin( } } - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = + k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; int par_id = 0; int locks_off = 0; - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * 
thread_m_blocks * prob_n / 8; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) { - // when parallel * n_tiles >= sms + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms // then there are at most $sms$ conflict tile blocks locks_off = blockIdx.x; } else { @@ -415,10 +534,11 @@ __global__ void Marlin( // Compute all information about the current slice which is required for // synchronization. - auto init_slice = [&](bool first_init = false) { + bool first_init = true; + auto init_part2_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; @@ -436,7 +556,7 @@ __global__ void Marlin( if (col_off > 0) slice_idx--; } } - if (parallel * n_tiles >= gridDim.x) { + if (part2_mn_tiles >= gridDim.x) { if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } @@ -466,28 +586,68 @@ __global__ void Marlin( } if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * lda / 8; + A += 16 * thread_m_blocks * lda / (is_a_8bit ? 16 : 8); C += 16 * thread_m_blocks * prob_n / 8; slice_col = 0; par_id++; } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } }; - init_slice(true); + + auto init_part1_slice = [&]() { + if (part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 
16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + if (is_a_8bit) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); // A sizes/strides // stride of the A matrix in global memory - int a_gl_stride = lda / 8; + int a_gl_stride = lda / (is_a_8bit ? 16 : 8); // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 
16 : 8); // between subsequent accesses within a tile int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // within a shared memory tile constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // overall size of a tile @@ -496,24 +656,25 @@ __global__ void Marlin( constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = + ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_stage = + b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = + 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? 
thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -527,7 +688,7 @@ __global__ void Marlin( int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides @@ -550,17 +711,22 @@ __global__ void Marlin( int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; + + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; int slice_k_start = tb_k * slice_row; int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_start_shared_fetch = slice_k_start; @@ -571,58 +737,54 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / - (w_type == vllm::kFE2M1f ? 
2 : 1) + + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; if constexpr (has_zp) { if constexpr (group_blocks == -1) { zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { + } else if constexpr (group_blocks >= thread_k_blocks) { zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; } } auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); } else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; int bias_sh_rd; if constexpr (m_block_size_8) { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 8; + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; } else { - bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; } @@ -638,12 +800,16 @@ __global__ void Marlin( if constexpr (has_zp) { if constexpr (is_zp_float) { if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; + zp_sh_rd = + 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % tb_n_warps / 2) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } else { zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } } @@ -678,26 +844,19 @@ __global__ void Marlin( for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; constexpr int sh_b_size = stages * b_sh_stage; - int4* sh_b = sh; - int4* sh_red = sh; - + int4* sh_b = sh_new; + int4* sh_red = sh_new; constexpr int sh_size_b_red_min = (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); constexpr int sh_size_b_red_max = @@ -708,8 +867,8 @@ __global__ void Marlin( ? 
sh_size_b_red_max : (sh_size_b_red_min + sh_bias_size); - int4* sh_bias = sh + sh_size_b_red_min; - int4* sh_g_idx = sh + sh_b_red_bias_size; + int4* sh_bias = sh_new + sh_size_b_red_min; + int4* sh_g_idx = sh_new + sh_b_red_bias_size; int4* sh_zp = sh_g_idx + (stages * g_idx_stage); constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); @@ -723,7 +882,8 @@ __global__ void Marlin( // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; FragS frag_s[2][4]; // No act-order FragS frag_bias[2][4]; FragS act_frag_s[2][4][4]; // For act-order @@ -731,6 +891,24 @@ __global__ void Marlin( FragZP frag_zp; // Zero-points in fp16 FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + if constexpr (is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + // Zero accumulators. 
auto zero_accums = [&]() { #pragma unroll @@ -788,15 +966,17 @@ __global__ void Marlin( } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = + b_gl_rd + (i % count) * threads + + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); - B_ptr[i] += b_gl_rd_delta_o; + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); } + b_gl_rd += b_gl_rd_delta_o; + if constexpr (has_act_order) { // Fetch g_idx thread-block portion int full_pipe = a_off; @@ -816,44 +996,24 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } 
else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -891,14 +1051,14 @@ __global__ void Marlin( int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm( + ldsm( frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -922,53 +1082,54 @@ __global__ void Marlin( auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 
2 : 1); if constexpr (!has_act_order) { // No act-order case if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = - reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } - } else { + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / tb_n_warps; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = - k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != vllm::kFE2M1f.id()) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( - sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + - k % 2]; + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; } } } @@ -989,18 +1150,15 @@ __global__ void Marlin( cur_k = 0; // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + cur_k += k % b_sh_wr_iters; // Determine "position" inside the thread-block (based on warp and // thread-id) auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; + cur_k += warp_row * 16 * b_sh_wr_iters; auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix @@ -1055,18 +1213,16 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { // load only when starting a new slice - if (k == 0 && full_pipe == 0) { + if (k == 0 && full_pipe == 0 || is_a_8bit) { #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = 
(reinterpret_cast(sh_zp))[zp_sh_rd + i]; } } - } else if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = @@ -1075,21 +1231,11 @@ __global__ void Marlin( } } else { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / tb_n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 
2 : 1); int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1108,29 +1254,18 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_zp_stage = - sh_zp + - zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; } - } else { + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; @@ -1141,33 +1276,46 @@ __global__ void Marlin( } }; - auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } }; // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) { + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) return; int k2 = k % 2; + constexpr int g = + group_blocks > 0 ? 
div_ceil(group_blocks, thread_k_blocks) : 1; const bool is_new_zp = - ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == 0) || + ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && + (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; int zp_quant_0, zp_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (b_type.size_bits() == 4) { zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = zp_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); zp_quant_0 = frag_qzp[k2][0]; zp_quant_1 = frag_qzp[k2][1]; } - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); } } if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { @@ -1177,14 +1325,14 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (b_type == vllm::kFE2M1f) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales( - s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales( - s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1195,61 +1343,168 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (b_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; - } else if constexpr (w_type.size_bits() == 4) { + } else if constexpr 
(b_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { - static_assert(w_type.size_bits() == 8); + static_assert(b_type.size_bits() == 8); int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - sub_zp(frag_b1, frag_zp[j], 1); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); } // Apply scale to frag_b0 - if constexpr (has_act_order) { + if constexpr (has_act_order && !is_a_8bit) { static_assert(group_blocks != -1); - scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && - group_blocks == -1) { + group_blocks == -1 && !is_a_8bit) { int idx = (threadIdx.x / 4) % 2; - scalar_t2 s2 = Dtype::nums2num2( + scalar_t2 s2 = Adtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + 
scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && + !is_a_8bit) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } else if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; + #pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = + reinterpret_cast(&frag_b_quant[k2][j])[0]; + 
reinterpret_cast(&frag_b)[1] = + reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma(frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * + scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += + *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * + scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || + s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t* s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = + *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = 
reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } } } } @@ -1263,7 +1518,8 @@ __global__ void Marlin( constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_stride = + b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); @@ -1278,7 +1534,8 @@ __global__ void Marlin( for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { @@ -1287,24 +1544,26 @@ __global__ void Marlin( float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 
2 : 1)) { float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); @@ -1320,10 +1579,10 @@ __global__ void Marlin( // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_o = 8 * c_gl_stride * (is_a_8bit ? 2 : 1); int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr; if constexpr (m_block_size_8) { @@ -1331,9 +1590,9 @@ __global__ void Marlin( 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; c_gl_wr += (2 * thread_n_blocks) * slice_col; } else { - c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) * (is_a_8bit ? 2 : 1) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 
2 : 1); } constexpr int c_sh_wr_delta = active_threads; auto c_sh_wr = threadIdx.x; @@ -1351,6 +1610,14 @@ __global__ void Marlin( &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], (threadIdx.x % 4) * 2 + i < prob_m); + } else if constexpr (is_a_8bit) { + int2* sh_red_int2 = reinterpret_cast(sh_red); + int2* c_int2 = reinterpret_cast(C); + cp_async2_ca_pred( + &sh_red_int2[c_sh_wr + c_sh_wr_delta * i], + &c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } else { cp_async4_pred( &sh_red[c_sh_wr + c_sh_wr_delta * i], @@ -1370,36 +1637,51 @@ __global__ void Marlin( (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); if (mask) { if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_scalar_t* c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = + reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + + (i % 4) + delta] += Cdtype::num2float(c_red_f16[j]); } } if (!last) { - int4 c; + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) { + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { int delta = 0; if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 
2 : 4) * 2 * 4 * (i / 4) + 4 * j + + (i % 4) + delta]); } - if constexpr (m_block_size_8) + if constexpr (m_block_size_8) { C[c_gl_wr + i * c_gl_stride + - (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; - else + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = + *reinterpret_cast(c_f16); + } else if constexpr (is_a_8bit) { + int2* c_int2 = reinterpret_cast(C); + c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)] = + *reinterpret_cast(c_f16); + } else { C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)] = c; + c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast(c_f16); + } } } } @@ -1414,10 +1696,10 @@ __global__ void Marlin( constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; + constexpr int active_threads = 32 * tb_n_warps; bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; constexpr int th_size = num_floats * sizeof(float) / 16; int c_cur_offset = locks_off * c_size; @@ -1471,7 +1753,7 @@ __global__ void Marlin( } else { c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); + c_sh_wr += (is_a_8bit ? 
16 : 32) * (threadIdx.x / 32); } int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + @@ -1481,47 +1763,47 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + c_scalar_t2 res = + Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && + b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { - scalar_t2 tmp_scale = s[0]; + c_scalar_t2 tmp_scale = s[0]; if constexpr (m_block_size_8) { - tmp_scale = Dtype::num2num2( + tmp_scale = Cdtype::num2num2( reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); } res = __hmul2(res, tmp_scale); } - if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { res = __hmul2(res, global_scale); } if (has_bias && last) { - scalar_t2 tmp_bias = b_bias[0]; + c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { - tmp_bias = Dtype::num2num2( + tmp_bias = Cdtype::num2num2( reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); } res = __hadd2(res, tmp_bias); } if constexpr (m_block_size_8) { - ((scalar_t*)sh_red)[idx] = res.x; - ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + ((c_scalar_t*)sh_red)[idx] = res.x; + ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; } else { - ((scalar_t2*)sh_red)[idx] = res; + ((c_scalar_t2*)sh_red)[idx] = res; } }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) { + for 
(int j = 0; j < (is_a_8bit ? 2 : 4); j++) { if constexpr (m_block_size_8) { int wr = c_sh_wr + 16 * j; write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], @@ -1557,9 +1839,9 @@ __global__ void Marlin( i++) { if (c_gl_wr < c_gl_wr_end) { if (use_atomic_add && slice_count > 1) { - scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); + c_scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + c_scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); #pragma unroll for (int a = 0; a < 4; a++) { atomicAdd(&C_half2[a], sh_red_half2[a]); @@ -1635,7 +1917,13 @@ __global__ void Marlin( wait_for_stage(); init_same_group(pipe % stages); } - matmul(k); + + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0)); + } else { + static_assert(group_blocks != 0 && group_blocks != 1); + matmul_a8(k); + } } slice_iters--; if (slice_iters == 0) { @@ -1668,13 +1956,47 @@ __global__ void Marlin( // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. 
if (slice_iters == 0) { + if constexpr (is_a_8bit) { + float frag_a_s[2 * thread_m_blocks]; + + for (int i = 0; i < 2 * thread_m_blocks; i++) + frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4]; + + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][0][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][0][g] = c_val * s_val; + } + #pragma unroll + for (int g = 0; g < 4; g++) { + float c_val = frag_c[i][j][1][g]; + + if constexpr (a_type == vllm::kS8) { + c_val = __int2float_rn(*reinterpret_cast(&c_val)); + } + float s_val = frag_a_s[i * 2 + g / 2]; + frag_c[i][j][1][g] = c_val * s_val; + } + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1692,20 +2014,27 @@ __global__ void Marlin( } if constexpr (!has_act_order && group_blocks == -1 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) { + if constexpr (is_a_8bit) { cp_async_wait<0>(); __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if (threadIdx.x / 32 < tb_n_warps) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + } + } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < tb_n_warps) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; 
reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; if constexpr (m_block_size_8) { int idx = (threadIdx.x / 4) % 2; - scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + c_scalar_t2* frag_s_half2 = + reinterpret_cast(frag_s); #pragma unroll for (int i = 0; i < 8; i++) { - frag_s_half2[i] = Dtype::num2num2( - reinterpret_cast(&frag_s_half2[i])[idx]); + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); } } } @@ -1715,26 +2044,48 @@ __global__ void Marlin( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && - (has_zp && dequant_skip_flop || !has_zp)) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { + #pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + + #pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && + b_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < tb_n_warps) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]); if constexpr (!m_block_size_8) { - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( + scale_float( reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } @@ -1758,7 +2109,8 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; - reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + if constexpr (!is_a_8bit) + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; __syncthreads(); } @@ -1768,21 +2120,22 @@ __global__ void Marlin( // only the last block in a slice actually writes the result write_result(last); slice_row = 0; - slice_col_par++; - slice_col++; + if (!in_part2) { + slice_col_par += gridDim.x; + } else { + slice_col_par++; + slice_col++; + } is_first_matmul_in_slice = true; init_slice(); if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } + a_gl_rd += a_gl_rd_delta_o * slice_row; + b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; // Update slice k/n for scales loading @@ -1791,12 +2144,28 @@ __global__ void Marlin( slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr 
(group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } - start_pipes(); } } diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 1001af05ff003..c5012a8669317 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif -#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ - defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \ - defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120 +#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \ + (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \ + (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120) void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. 
int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k, @@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data( false, "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". Required capability: 90, 100, or 120"); } void get_cutlass_moe_mm_problem_sizes( @@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes( torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets) { int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets); @@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes( false, "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " "kernel for CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". 
Required capability: 90, 100, or 120"); } void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, @@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ - (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ + (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, num_local_experts, padded_m, n, k); @@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, false, "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " "for CUDA device capability: ", - version_num, ". Required capability: 90 or 100"); + version_num, ". Required capability: 90, 100, or 120"); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 14913bef13125..914227838558a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); -#ifndef USE_ROCM // Merge attn states // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // can be used to combine partial attention results (in the split-KV case) @@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_output," " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); - +#ifndef USE_ROCM ops.def( "convert_vertical_slash_indexes(" " Tensor! 
block_count, Tensor! block_offset, " @@ -299,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor? b_bias_or_none," - "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " - "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " + "Tensor? b_bias_or_none,Tensor b_scales, " + "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, " + "Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); // conditionally compiled so impl registration is in source file @@ -309,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin repack from GPTQ. ops.def( "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " - "SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); + "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor"); // conditionally compiled so impl registrations are in source file // awq_marlin repack from AWQ. ops.def( "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " - "SymInt size_n, int num_bits) -> Tensor"); + "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor"); + // conditionally compiled so impl registrations are in source file + + // preprocess W-int4A-fp8 weight for marlin kernel + ops.def( + "marlin_int4_fp8_preprocess(Tensor qweight, " + "Tensor? 
qzeros_or_none, bool inplace) -> Tensor"); // conditionally compiled so impl registrations are in source file // CUTLASS w4a8 GEMM diff --git a/docker/Dockerfile b/docker/Dockerfile index 84a1802dbe03a..006481b23cb9f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -244,9 +244,15 @@ RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh # Install EP kernels(pplx-kernels and DeepEP) +ARG PPLX_COMMIT_HASH +ARG DEEPEP_COMMIT_HASH RUN --mount=type=cache,target=/root/.cache/uv \ export TORCH_CUDA_ARCH_LIST='9.0a 10.0a' && \ - /tmp/install_python_libraries.sh /tmp/ep_kernels_workspace wheel && \ + /tmp/install_python_libraries.sh \ + --workspace /tmp/ep_kernels_workspace \ + --mode wheel \ + ${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \ + ${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} && \ find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete # Check the size of the wheel if RUN_WHEEL_CHECK is true @@ -358,7 +364,12 @@ RUN CUDA_VERSION_DASH=$(echo $CUDA_VERSION | cut -d. -f1,2 | tr '.' 
'-') && \ cuda-cudart-${CUDA_VERSION_DASH} \ cuda-nvrtc-${CUDA_VERSION_DASH} \ cuda-cuobjdump-${CUDA_VERSION_DASH} \ - libcublas-${CUDA_VERSION_DASH} && \ + # https://github.com/vllm-project/vllm/issues/29590 + libcurand-dev-${CUDA_VERSION_DASH} \ + libcublas-${CUDA_VERSION_DASH} \ + # Fixes nccl_allocator requiring nccl.h at runtime + # https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22 + libnccl-dev && \ rm -rf /var/lib/apt/lists/* ARG PIP_INDEX_URL UV_INDEX_URL @@ -392,8 +403,8 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer pre-compiled kernel cache and binaries # https://docs.flashinfer.ai/installation.html RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system flashinfer-cubin==0.5.2 \ - && uv pip install --system flashinfer-jit-cache==0.5.2 \ + uv pip install --system flashinfer-cubin==0.5.3 \ + && uv pip install --system flashinfer-jit-cache==0.5.3 \ --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.') \ && flashinfer show-config diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index eb3807ef0ca4e..8d55ecfba3e52 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -132,7 +132,7 @@ RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ esac; \ }; \ remove_packages_not_supported_on_aarch64 && \ - sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \ + sed -i 's/^torch==.*/torch==2.9.1/g' requirements/cpu-test.in && \ sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 42466d1801cf6..1b6bdabc7a539 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -65,6 +65,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1 # ----------------------- # Test vLLM image @@ -88,10 +89,22 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace # install development dependencies (for testing) RUN cd /vllm-workspace \ - && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ && python3 -m pip install pytest-shard +# enable fast downloads from hf (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system hf_transfer +ENV HF_HUB_ENABLE_HF_TRANSFER=1 + +# Copy in the v1 package (for python-only install test group) +COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 + +# Source code is used in the `python_only_compile.sh` test +# We hide it inside `src/` 
so that this source code +# will not be imported by other tests +RUN mkdir src && mv vllm src/vllm + # ----------------------- # Final vLLM image FROM base AS final diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index df4f9b6c26e7d..a57ee728d9243 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -5,6 +5,8 @@ ARG PYTORCH_BRANCH="1c57644d" ARG PYTORCH_VISION_BRANCH="v0.23.0" ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" +ARG PYTORCH_AUDIO_BRANCH="v2.9.0" +ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git" ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG AITER_BRANCH="59bd8ff2" @@ -23,6 +25,7 @@ ENV AITER_ROCM_ARCH=gfx942;gfx950 ENV HSA_NO_SCRATCH_RECLAIM=1 ARG PYTHON_VERSION=3.12 +ENV PYTHON_VERSION=${PYTHON_VERSION} RUN mkdir -p /app WORKDIR /app @@ -45,6 +48,7 @@ RUN apt-get update -y \ && python3 --version && python3 -m pip --version RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython +RUN apt-get update && apt-get install -y libjpeg-dev libsox-dev libsox-fmt-all sox && rm -rf /var/lib/apt/lists/* FROM base AS build_triton ARG TRITON_BRANCH @@ -66,11 +70,14 @@ RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install FROM base AS build_pytorch ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_AUDIO_BRANCH ARG PYTORCH_REPO ARG PYTORCH_VISION_REPO +ARG PYTORCH_AUDIO_REPO + RUN git clone ${PYTORCH_REPO} pytorch -RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ - pip install -r requirements.txt && git submodule update --init --recursive \ +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \ + && pip install -r requirements.txt && git submodule update --init --recursive \ && python3 tools/amd_build/build_amd.py \ && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ && 
pip install dist/*.whl @@ -78,8 +85,15 @@ RUN git clone ${PYTORCH_VISION_REPO} vision RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ && python3 setup.py bdist_wheel --dist-dir=dist \ && pip install dist/*.whl +RUN git clone ${PYTORCH_AUDIO_REPO} audio +RUN cd audio && git checkout ${PYTORCH_AUDIO_BRANCH} \ + && git submodule update --init --recursive \ + && pip install -r requirements.txt \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ - && cp /app/vision/dist/*.whl /app/install + && cp /app/vision/dist/*.whl /app/install \ + && cp /app/audio/dist/*.whl /app/install FROM base AS build_fa ARG FA_BRANCH @@ -130,6 +144,8 @@ ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH ARG PYTORCH_REPO ARG PYTORCH_VISION_REPO +ARG PYTORCH_AUDIO_BRANCH +ARG PYTORCH_AUDIO_REPO ARG FA_BRANCH ARG FA_REPO ARG AITER_BRANCH @@ -141,7 +157,9 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_AUDIO_BRANCH: ${PYTORCH_AUDIO_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_AUDIO_REPO: ${PYTORCH_AUDIO_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ - && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file diff --git a/docs/.nav.yml b/docs/.nav.yml index c8bf00efb2370..aa98ad52be215 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -5,11 +5,7 @@ nav: - Getting Started: - getting_started/quickstart.md - getting_started/installation - - Examples: - - examples/README.md - - Offline Inference: 
examples/offline_inference - - Online Serving: examples/online_serving - - Others: examples/others + - Examples: examples - General: - usage/v1_guide.md - usage/* @@ -52,6 +48,11 @@ nav: - Plugins: - design/*plugin*.md - design/* + - Benchmarking: + - benchmarking/README.md + - benchmarking/cli.md + - benchmarking/sweeps.md + - benchmarking/dashboard.md - API Reference: - api/README.md - api/vllm diff --git a/docs/benchmarking/README.md b/docs/benchmarking/README.md new file mode 100644 index 0000000000000..238290d4762b3 --- /dev/null +++ b/docs/benchmarking/README.md @@ -0,0 +1,7 @@ +# Benchmark Suites + +vLLM provides comprehensive benchmarking tools for performance testing and evaluation: + +- **[Benchmark CLI](./cli.md)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing. +- **[Parameter Sweeps](./sweeps.md)**: Automate `vllm bench` runs for multiple configurations, useful for [optimization and tuning](../configuration/optimization.md). +- **[Performance Dashboard](./dashboard.md)**: Automated CI that publishes benchmarks on each commit. diff --git a/docs/contributing/benchmarks.md b/docs/benchmarking/cli.md similarity index 72% rename from docs/contributing/benchmarks.md rename to docs/benchmarking/cli.md index c9bc9cfe28a35..44a4c40125952 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/benchmarking/cli.md @@ -1,22 +1,10 @@ ---- -toc_depth: 4 ---- +# Benchmark CLI -# Benchmark Suites +This section guides you through running benchmark tests with the extensive datasets supported on vLLM. -vLLM provides comprehensive benchmarking tools for performance testing and evaluation: +It's a living document, updated as new features and datasets become available. 
-- **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing -- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations -- **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development - -## Benchmark CLI - -This section guides you through running benchmark tests with the extensive -datasets supported on vLLM. It's a living document, updated as new features and datasets -become available. - -### Dataset Overview +## Dataset Overview